1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19 LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20 VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an [`EquiDepthHistogram`], covering the half-open value
/// interval `[lower_bound, upper_bound)`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket's value interval.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket's value interval.
    pub upper_bound: f64,
    /// Number of rows whose value falls into this bucket.
    pub frequency: u64,
    /// Number of distinct values among those rows.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and its row / distinct-value counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Width of the value interval, `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` if `value` lies in `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        value >= self.lower_bound && value < self.upper_bound
    }

    /// Fraction of this bucket that falls inside the query range
    /// `lower..upper`, where `None` leaves that side unbounded.
    ///
    /// Values are assumed to be spread uniformly across the bucket, so the
    /// result is the overlapping width divided by the bucket width, capped at
    /// `1.0`. A zero-width bucket is all-or-nothing: `1.0` if the clipped query
    /// range reaches both of its bounds, `0.0` otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the query range to the bucket's own interval.
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            // Nothing to interpolate over.
            let covered =
                effective_lower <= self.lower_bound && effective_upper >= self.upper_bound;
            return if covered { 1.0 } else { 0.0 };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}

/// Equi-depth (equal-frequency) histogram over a numeric column.
///
/// [`build`](Self::build) gives every bucket roughly the same number of rows,
/// so dense value ranges get narrow buckets and sparse ranges get wide ones.
/// The selectivity estimators interpolate linearly inside a bucket.
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// Buckets in ascending value order (as produced by `build`).
    buckets: Vec<HistogramBucket>,
    /// Sum of all bucket frequencies.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Wraps pre-computed buckets; `total_rows` becomes the sum of their
    /// frequencies.
    ///
    /// The estimators scan buckets in order, so buckets are presumably expected
    /// to be sorted and non-overlapping, like those `build` produces.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds a histogram with at most `num_buckets` buckets.
    ///
    /// `values` must be sorted in ascending order and free of NaN. Each bucket
    /// receives about `values.len() / num_buckets` rows, but a run of equal
    /// values is never split across two buckets, so heavily duplicated data can
    /// yield fewer buckets than requested. A bucket's upper bound is the first
    /// value of the next bucket; the last bucket ends at the smallest `f64`
    /// above the maximum, so the maximum itself is still contained.
    ///
    /// Returns an empty histogram when `values` is empty or `num_buckets == 0`.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self::new(Vec::new());
        }

        let num_buckets = num_buckets.min(values.len());
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        let mut start_idx = 0;
        while start_idx < values.len() {
            let mut end_idx = (start_idx + rows_per_bucket).min(values.len());
            // Swallow the rest of a run of equal values. Splitting the run would
            // leave its head below this bucket's exclusive upper bound (that very
            // value), where `contains` can never find it.
            while end_idx < values.len() && values[end_idx] == values[end_idx - 1] {
                end_idx += 1;
            }

            let bucket_values = &values[start_idx..end_idx];
            let upper_bound = match values.get(end_idx) {
                Some(&next) => next,
                // A fixed offset such as `max + 1.0` would distort the last bucket's
                // width for small-scale data and vanish to rounding for very large
                // magnitudes; one ULP keeps the bucket tight at any scale.
                None => values[end_idx - 1].next_up(),
            };

            buckets.push(HistogramBucket::new(
                bucket_values[0],
                upper_bound,
                bucket_values.len() as u64,
                count_distinct(bucket_values),
            ));

            start_idx = end_idx;
        }

        Self::new(buckets)
    }

    /// Number of buckets.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Total number of rows summarized by the histogram.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// The buckets in stored order.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimated fraction of rows whose value lies in `lower..upper` (`None`
    /// = unbounded on that side), interpolating linearly inside partially
    /// covered buckets. Falls back to `0.33` for an empty histogram.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33;
        }

        let matching_rows = self
            .buckets
            .iter()
            // Skip buckets entirely below `lower` or at/above `upper`.
            .filter(|bucket| {
                !lower.is_some_and(|l| bucket.upper_bound <= l)
                    && !upper.is_some_and(|u| bucket.lower_bound >= u)
            })
            .fold(0.0, |acc, bucket| {
                acc + bucket.overlap_fraction(lower, upper) * bucket.frequency as f64
            });

        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
    }

    /// Estimated fraction of rows equal to `value`.
    ///
    /// Uses the containing bucket's average number of rows per distinct value.
    /// Returns `0.01` for an empty histogram and `0.001` when no bucket
    /// contains `value`.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01;
        }

        self.buckets
            .iter()
            .find(|bucket| bucket.contains(value) && bucket.distinct_count > 0)
            .map_or(0.001, |bucket| {
                let rows_per_value = bucket.frequency as f64 / bucket.distinct_count as f64;
                (rows_per_value / self.total_rows as f64).min(1.0)
            })
    }

    /// Inclusive lower bound of the first bucket, if any.
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Exclusive upper bound of the last bucket, if any. For a histogram from
    /// `build`, this is just above the largest input value.
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts the distinct values in an ascending-sorted slice.
///
/// Values are compared exactly (`-0.0` equals `0.0`). NaN compares unequal to
/// everything, so every NaN would be counted separately; remove NaN first.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    // In sorted data, each change between neighbours starts a new distinct value.
    let value_changes = sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();
    value_changes as u64 + 1
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
/// Fallback selectivities (fractions of rows kept, in `0.0..=1.0`) used by the
/// cardinality estimator when no column statistics apply.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Selectivity for predicates without a more specific rule.
    pub default: f64,
    /// Equality predicates (`=`).
    pub equality: f64,
    /// Inequality predicates (`<>`).
    pub inequality: f64,
    /// Range comparisons (`<`, `<=`, `>`, `>=`).
    pub range: f64,
    /// String predicates (`STARTS WITH`, `ENDS WITH`, `CONTAINS`, `LIKE`).
    pub string_ops: f64,
    /// Membership tests (`IN`).
    pub membership: f64,
    /// `IS NULL`.
    pub is_null: f64,
    /// `IS NOT NULL`.
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive `DISTINCT`.
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns the built-in defaults; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}
446
/// One recorded comparison between an estimated and an observed cardinality.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Name of the operator the estimate was made for.
    pub operator: String,
    /// Estimated cardinality.
    pub estimated: f64,
    /// Observed cardinality.
    pub actual: f64,
}

impl EstimationEntry {
    /// Ratio `actual / estimated`.
    ///
    /// `1.0` is a perfect estimate, `> 1.0` an underestimate and `< 1.0` an
    /// overestimate. When the estimate is (approximately) zero, the result is
    /// `1.0` if the actual count is also zero and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() >= f64::EPSILON {
            return self.actual / self.estimated;
        }
        if self.actual.abs() < f64::EPSILON {
            1.0
        } else {
            f64::INFINITY
        }
    }
}

/// Collects estimated-vs-actual cardinalities and decides whether the
/// misestimates are large enough to warrant replanning.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries, in insertion order.
    entries: Vec<EstimationEntry>,
    /// Error factor beyond which replanning is suggested; expected to be > 1.
    replan_threshold: f64,
}

impl EstimationLog {
    /// Threshold used by `Default`: a 10x misestimate in either direction.
    const DEFAULT_REPLAN_THRESHOLD: f64 = 10.0;

    /// Creates an empty log. `replan_threshold` is the tolerated error factor
    /// (for example `10.0` tolerates up to 10x over- or underestimation).
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Records one estimated/actual pair for `operator`.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// All recorded entries.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns `true` if any entry's error ratio lies outside
    /// `[1 / replan_threshold, replan_threshold]`.
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Largest symmetric error factor over all entries: the error ratio or its
    /// reciprocal, whichever is at least `1.0`. Returns `1.0` for an empty log.
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Removes all entries; the threshold is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}

impl Default for EstimationLog {
    /// An empty log with a 10x replan threshold. A derived `Default` would use
    /// a threshold of `0.0`, which makes every recorded entry trigger a replan.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REPLAN_THRESHOLD)
    }
}
547
548pub struct CardinalityEstimator {
550 table_stats: HashMap<String, TableStats>,
552 default_row_count: u64,
554 default_selectivity: f64,
556 avg_fanout: f64,
558 selectivity_config: SelectivityConfig,
560}
561
562impl CardinalityEstimator {
563 #[must_use]
565 pub fn new() -> Self {
566 let config = SelectivityConfig::new();
567 Self {
568 table_stats: HashMap::new(),
569 default_row_count: 1000,
570 default_selectivity: config.default,
571 avg_fanout: 10.0,
572 selectivity_config: config,
573 }
574 }
575
576 #[must_use]
578 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579 Self {
580 table_stats: HashMap::new(),
581 default_row_count: 1000,
582 default_selectivity: config.default,
583 avg_fanout: 10.0,
584 selectivity_config: config,
585 }
586 }
587
588 #[must_use]
590 pub fn selectivity_config(&self) -> &SelectivityConfig {
591 &self.selectivity_config
592 }
593
594 #[must_use]
596 pub fn create_estimation_log() -> EstimationLog {
597 EstimationLog::new(10.0)
598 }
599
600 #[must_use]
606 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607 let mut estimator = Self::new();
608
609 if stats.total_nodes > 0 {
611 estimator.default_row_count = stats.total_nodes;
612 }
613
614 for (label, label_stats) in &stats.labels {
616 let mut table_stats = TableStats::new(label_stats.node_count);
617
618 for (prop, col_stats) in &label_stats.properties {
620 let optimizer_col =
621 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622 table_stats = table_stats.with_column(prop, optimizer_col);
623 }
624
625 estimator.add_table_stats(label, table_stats);
626 }
627
628 if !stats.edge_types.is_empty() {
630 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632 } else if stats.total_nodes > 0 {
633 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634 }
635
636 if estimator.avg_fanout < 1.0 {
638 estimator.avg_fanout = 1.0;
639 }
640
641 estimator
642 }
643
644 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646 self.table_stats.insert(name.to_string(), stats);
647 }
648
649 pub fn set_avg_fanout(&mut self, fanout: f64) {
651 self.avg_fanout = fanout;
652 }
653
654 #[must_use]
656 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657 match op {
658 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660 LogicalOperator::Project(project) => self.estimate_project(project),
661 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662 LogicalOperator::Join(join) => self.estimate_join(join),
663 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668 LogicalOperator::Return(ret) => self.estimate(&ret.input),
669 LogicalOperator::Empty => 0.0,
670 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672 _ => self.default_row_count as f64,
673 }
674 }
675
676 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
678 if let Some(label) = &scan.label
679 && let Some(stats) = self.table_stats.get(label)
680 {
681 return stats.row_count as f64;
682 }
683 self.default_row_count as f64
685 }
686
687 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
689 let input_cardinality = self.estimate(&filter.input);
690 let selectivity = self.estimate_selectivity(&filter.predicate);
691 (input_cardinality * selectivity).max(1.0)
692 }
693
694 fn estimate_project(&self, project: &ProjectOp) -> f64 {
696 self.estimate(&project.input)
697 }
698
699 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
701 let input_cardinality = self.estimate(&expand.input);
702
703 let fanout = if !expand.edge_types.is_empty() {
705 self.avg_fanout * 0.5
707 } else {
708 self.avg_fanout
709 };
710
711 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
713 let min = expand.min_hops as f64;
714 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
715 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
717 } else {
718 fanout
719 };
720
721 (input_cardinality * path_multiplier).max(1.0)
722 }
723
724 fn estimate_join(&self, join: &JoinOp) -> f64 {
726 let left_card = self.estimate(&join.left);
727 let right_card = self.estimate(&join.right);
728
729 match join.join_type {
730 JoinType::Cross => left_card * right_card,
731 JoinType::Inner => {
732 let selectivity = if join.conditions.is_empty() {
734 1.0 } else {
736 0.1_f64.powi(join.conditions.len() as i32)
738 };
739 (left_card * right_card * selectivity).max(1.0)
740 }
741 JoinType::Left => {
742 let inner_card = self.estimate_join(&JoinOp {
744 left: join.left.clone(),
745 right: join.right.clone(),
746 join_type: JoinType::Inner,
747 conditions: join.conditions.clone(),
748 });
749 inner_card.max(left_card)
750 }
751 JoinType::Right => {
752 let inner_card = self.estimate_join(&JoinOp {
754 left: join.left.clone(),
755 right: join.right.clone(),
756 join_type: JoinType::Inner,
757 conditions: join.conditions.clone(),
758 });
759 inner_card.max(right_card)
760 }
761 JoinType::Full => {
762 let inner_card = self.estimate_join(&JoinOp {
764 left: join.left.clone(),
765 right: join.right.clone(),
766 join_type: JoinType::Inner,
767 conditions: join.conditions.clone(),
768 });
769 inner_card.max(left_card.max(right_card))
770 }
771 JoinType::Semi => {
772 (left_card * self.default_selectivity).max(1.0)
774 }
775 JoinType::Anti => {
776 (left_card * (1.0 - self.default_selectivity)).max(1.0)
778 }
779 }
780 }
781
782 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
784 let input_cardinality = self.estimate(&agg.input);
785
786 if agg.group_by.is_empty() {
787 1.0
789 } else {
790 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
793 (input_cardinality / group_reduction).max(1.0)
794 }
795 }
796
797 fn estimate_sort(&self, sort: &SortOp) -> f64 {
799 self.estimate(&sort.input)
800 }
801
802 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
804 let input_cardinality = self.estimate(&distinct.input);
805 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
806 }
807
808 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
810 let input_cardinality = self.estimate(&limit.input);
811 (limit.count as f64).min(input_cardinality)
812 }
813
814 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
816 let input_cardinality = self.estimate(&skip.input);
817 (input_cardinality - skip.count as f64).max(0.0)
818 }
819
820 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
825 let base_k = scan.k as f64;
826
827 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
829 0.7
831 } else {
832 1.0
833 };
834
835 (base_k * selectivity).max(1.0)
836 }
837
838 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
842 let input_cardinality = self.estimate(&join.input);
843 let k = join.k as f64;
844
845 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
847 0.7
848 } else {
849 1.0
850 };
851
852 (input_cardinality * k * selectivity).max(1.0)
853 }
854
855 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
857 match expr {
858 LogicalExpression::Binary { left, op, right } => {
859 self.estimate_binary_selectivity(left, *op, right)
860 }
861 LogicalExpression::Unary { op, operand } => {
862 self.estimate_unary_selectivity(*op, operand)
863 }
864 LogicalExpression::Literal(value) => {
865 if let grafeo_common::types::Value::Bool(b) = value {
867 if *b { 1.0 } else { 0.0 }
868 } else {
869 self.default_selectivity
870 }
871 }
872 _ => self.default_selectivity,
873 }
874 }
875
876 fn estimate_binary_selectivity(
878 &self,
879 left: &LogicalExpression,
880 op: BinaryOp,
881 right: &LogicalExpression,
882 ) -> f64 {
883 match op {
884 BinaryOp::Eq => {
886 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
887 return selectivity;
888 }
889 self.selectivity_config.equality
890 }
891 BinaryOp::Ne => self.selectivity_config.inequality,
893 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
895 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
896 return selectivity;
897 }
898 self.selectivity_config.range
899 }
900 BinaryOp::And => {
902 let left_sel = self.estimate_selectivity(left);
903 let right_sel = self.estimate_selectivity(right);
904 left_sel * right_sel
906 }
907 BinaryOp::Or => {
908 let left_sel = self.estimate_selectivity(left);
909 let right_sel = self.estimate_selectivity(right);
910 (left_sel + right_sel - left_sel * right_sel).min(1.0)
913 }
914 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
916 self.selectivity_config.string_ops
917 }
918 BinaryOp::In => self.selectivity_config.membership,
920 _ => self.default_selectivity,
922 }
923 }
924
925 fn try_equality_selectivity(
927 &self,
928 left: &LogicalExpression,
929 right: &LogicalExpression,
930 ) -> Option<f64> {
931 let (label, column, value) = self.extract_column_and_value(left, right)?;
933
934 let stats = self.get_column_stats(&label, &column)?;
936
937 if let Some(ref histogram) = stats.histogram {
939 return Some(histogram.equality_selectivity(value));
940 }
941
942 if stats.distinct_count > 0 {
944 return Some(1.0 / stats.distinct_count as f64);
945 }
946
947 None
948 }
949
950 fn try_range_selectivity(
952 &self,
953 left: &LogicalExpression,
954 op: BinaryOp,
955 right: &LogicalExpression,
956 ) -> Option<f64> {
957 let (label, column, value) = self.extract_column_and_value(left, right)?;
959
960 let stats = self.get_column_stats(&label, &column)?;
962
963 let (lower, upper) = match op {
965 BinaryOp::Lt => (None, Some(value)),
966 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
967 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
968 BinaryOp::Ge => (Some(value), None),
969 _ => return None,
970 };
971
972 if let Some(ref histogram) = stats.histogram {
974 return Some(histogram.range_selectivity(lower, upper));
975 }
976
977 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
979 let range = max - min;
980 if range <= 0.0 {
981 return Some(1.0);
982 }
983
984 let effective_lower = lower.unwrap_or(min).max(min);
985 let effective_upper = upper.unwrap_or(max).min(max);
986 let overlap = (effective_upper - effective_lower).max(0.0);
987 return Some((overlap / range).min(1.0).max(0.0));
988 }
989
990 None
991 }
992
993 fn extract_column_and_value(
998 &self,
999 left: &LogicalExpression,
1000 right: &LogicalExpression,
1001 ) -> Option<(String, String, f64)> {
1002 if let Some(result) = self.try_extract_property_literal(left, right) {
1004 return Some(result);
1005 }
1006
1007 self.try_extract_property_literal(right, left)
1009 }
1010
1011 fn try_extract_property_literal(
1013 &self,
1014 property_expr: &LogicalExpression,
1015 literal_expr: &LogicalExpression,
1016 ) -> Option<(String, String, f64)> {
1017 let (variable, property) = match property_expr {
1019 LogicalExpression::Property { variable, property } => {
1020 (variable.clone(), property.clone())
1021 }
1022 _ => return None,
1023 };
1024
1025 let value = match literal_expr {
1027 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1028 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1029 _ => return None,
1030 };
1031
1032 for label in self.table_stats.keys() {
1036 if let Some(stats) = self.table_stats.get(label)
1037 && stats.columns.contains_key(&property)
1038 {
1039 return Some((label.clone(), property, value));
1040 }
1041 }
1042
1043 Some((variable, property, value))
1045 }
1046
1047 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1049 match op {
1050 UnaryOp::Not => 1.0 - self.default_selectivity,
1051 UnaryOp::IsNull => self.selectivity_config.is_null,
1052 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1053 UnaryOp::Neg => 1.0, }
1055 }
1056
1057 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1059 self.table_stats.get(label)?.columns.get(column)
1060 }
1061}
1062
1063impl Default for CardinalityEstimator {
1064 fn default() -> Self {
1065 Self::new()
1066 }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071 use super::*;
1072 use crate::query::plan::{
1073 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1074 ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1075 };
1076 use grafeo_common::types::Value;
1077
1078 #[test]
1079 fn test_node_scan_with_stats() {
1080 let mut estimator = CardinalityEstimator::new();
1081 estimator.add_table_stats("Person", TableStats::new(5000));
1082
1083 let scan = LogicalOperator::NodeScan(NodeScanOp {
1084 variable: "n".to_string(),
1085 label: Some("Person".to_string()),
1086 input: None,
1087 });
1088
1089 let cardinality = estimator.estimate(&scan);
1090 assert!((cardinality - 5000.0).abs() < 0.001);
1091 }
1092
1093 #[test]
1094 fn test_filter_reduces_cardinality() {
1095 let mut estimator = CardinalityEstimator::new();
1096 estimator.add_table_stats("Person", TableStats::new(1000));
1097
1098 let filter = LogicalOperator::Filter(FilterOp {
1099 predicate: LogicalExpression::Binary {
1100 left: Box::new(LogicalExpression::Property {
1101 variable: "n".to_string(),
1102 property: "age".to_string(),
1103 }),
1104 op: BinaryOp::Eq,
1105 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1106 },
1107 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1108 variable: "n".to_string(),
1109 label: Some("Person".to_string()),
1110 input: None,
1111 })),
1112 });
1113
1114 let cardinality = estimator.estimate(&filter);
1115 assert!(cardinality < 1000.0);
1117 assert!(cardinality >= 1.0);
1118 }
1119
1120 #[test]
1121 fn test_join_cardinality() {
1122 let mut estimator = CardinalityEstimator::new();
1123 estimator.add_table_stats("Person", TableStats::new(1000));
1124 estimator.add_table_stats("Company", TableStats::new(100));
1125
1126 let join = LogicalOperator::Join(JoinOp {
1127 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1128 variable: "p".to_string(),
1129 label: Some("Person".to_string()),
1130 input: None,
1131 })),
1132 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1133 variable: "c".to_string(),
1134 label: Some("Company".to_string()),
1135 input: None,
1136 })),
1137 join_type: JoinType::Inner,
1138 conditions: vec![JoinCondition {
1139 left: LogicalExpression::Property {
1140 variable: "p".to_string(),
1141 property: "company_id".to_string(),
1142 },
1143 right: LogicalExpression::Property {
1144 variable: "c".to_string(),
1145 property: "id".to_string(),
1146 },
1147 }],
1148 });
1149
1150 let cardinality = estimator.estimate(&join);
1151 assert!(cardinality < 1000.0 * 100.0);
1153 }
1154
1155 #[test]
1156 fn test_limit_caps_cardinality() {
1157 let mut estimator = CardinalityEstimator::new();
1158 estimator.add_table_stats("Person", TableStats::new(1000));
1159
1160 let limit = LogicalOperator::Limit(LimitOp {
1161 count: 10,
1162 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1163 variable: "n".to_string(),
1164 label: Some("Person".to_string()),
1165 input: None,
1166 })),
1167 });
1168
1169 let cardinality = estimator.estimate(&limit);
1170 assert!((cardinality - 10.0).abs() < 0.001);
1171 }
1172
1173 #[test]
1174 fn test_aggregate_reduces_cardinality() {
1175 let mut estimator = CardinalityEstimator::new();
1176 estimator.add_table_stats("Person", TableStats::new(1000));
1177
1178 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1180 group_by: vec![],
1181 aggregates: vec![],
1182 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1183 variable: "n".to_string(),
1184 label: Some("Person".to_string()),
1185 input: None,
1186 })),
1187 having: None,
1188 });
1189
1190 let cardinality = estimator.estimate(&global_agg);
1191 assert!((cardinality - 1.0).abs() < 0.001);
1192
1193 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1195 group_by: vec![LogicalExpression::Property {
1196 variable: "n".to_string(),
1197 property: "city".to_string(),
1198 }],
1199 aggregates: vec![],
1200 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1201 variable: "n".to_string(),
1202 label: Some("Person".to_string()),
1203 input: None,
1204 })),
1205 having: None,
1206 });
1207
1208 let cardinality = estimator.estimate(&group_agg);
1209 assert!(cardinality < 1000.0);
1211 }
1212
1213 #[test]
1214 fn test_node_scan_without_stats() {
1215 let estimator = CardinalityEstimator::new();
1216
1217 let scan = LogicalOperator::NodeScan(NodeScanOp {
1218 variable: "n".to_string(),
1219 label: Some("Unknown".to_string()),
1220 input: None,
1221 });
1222
1223 let cardinality = estimator.estimate(&scan);
1224 assert!((cardinality - 1000.0).abs() < 0.001);
1226 }
1227
1228 #[test]
1229 fn test_node_scan_no_label() {
1230 let estimator = CardinalityEstimator::new();
1231
1232 let scan = LogicalOperator::NodeScan(NodeScanOp {
1233 variable: "n".to_string(),
1234 label: None,
1235 input: None,
1236 });
1237
1238 let cardinality = estimator.estimate(&scan);
1239 assert!((cardinality - 1000.0).abs() < 0.001);
1241 }
1242
1243 #[test]
1244 fn test_filter_inequality_selectivity() {
1245 let mut estimator = CardinalityEstimator::new();
1246 estimator.add_table_stats("Person", TableStats::new(1000));
1247
1248 let filter = LogicalOperator::Filter(FilterOp {
1249 predicate: LogicalExpression::Binary {
1250 left: Box::new(LogicalExpression::Property {
1251 variable: "n".to_string(),
1252 property: "age".to_string(),
1253 }),
1254 op: BinaryOp::Ne,
1255 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1256 },
1257 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1258 variable: "n".to_string(),
1259 label: Some("Person".to_string()),
1260 input: None,
1261 })),
1262 });
1263
1264 let cardinality = estimator.estimate(&filter);
1265 assert!(cardinality > 900.0);
1267 }
1268
1269 #[test]
1270 fn test_filter_range_selectivity() {
1271 let mut estimator = CardinalityEstimator::new();
1272 estimator.add_table_stats("Person", TableStats::new(1000));
1273
1274 let filter = LogicalOperator::Filter(FilterOp {
1275 predicate: LogicalExpression::Binary {
1276 left: Box::new(LogicalExpression::Property {
1277 variable: "n".to_string(),
1278 property: "age".to_string(),
1279 }),
1280 op: BinaryOp::Gt,
1281 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1282 },
1283 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1284 variable: "n".to_string(),
1285 label: Some("Person".to_string()),
1286 input: None,
1287 })),
1288 });
1289
1290 let cardinality = estimator.estimate(&filter);
1291 assert!(cardinality < 500.0);
1293 assert!(cardinality > 100.0);
1294 }
1295
1296 #[test]
1297 fn test_filter_and_selectivity() {
1298 let mut estimator = CardinalityEstimator::new();
1299 estimator.add_table_stats("Person", TableStats::new(1000));
1300
1301 let filter = LogicalOperator::Filter(FilterOp {
1304 predicate: LogicalExpression::Binary {
1305 left: Box::new(LogicalExpression::Binary {
1306 left: Box::new(LogicalExpression::Property {
1307 variable: "n".to_string(),
1308 property: "city".to_string(),
1309 }),
1310 op: BinaryOp::Eq,
1311 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1312 }),
1313 op: BinaryOp::And,
1314 right: Box::new(LogicalExpression::Binary {
1315 left: Box::new(LogicalExpression::Property {
1316 variable: "n".to_string(),
1317 property: "age".to_string(),
1318 }),
1319 op: BinaryOp::Eq,
1320 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1321 }),
1322 },
1323 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1324 variable: "n".to_string(),
1325 label: Some("Person".to_string()),
1326 input: None,
1327 })),
1328 });
1329
1330 let cardinality = estimator.estimate(&filter);
1331 assert!(cardinality < 100.0);
1334 assert!(cardinality >= 1.0);
1335 }
1336
1337 #[test]
1338 fn test_filter_or_selectivity() {
1339 let mut estimator = CardinalityEstimator::new();
1340 estimator.add_table_stats("Person", TableStats::new(1000));
1341
1342 let filter = LogicalOperator::Filter(FilterOp {
1346 predicate: LogicalExpression::Binary {
1347 left: Box::new(LogicalExpression::Binary {
1348 left: Box::new(LogicalExpression::Property {
1349 variable: "n".to_string(),
1350 property: "city".to_string(),
1351 }),
1352 op: BinaryOp::Eq,
1353 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1354 }),
1355 op: BinaryOp::Or,
1356 right: Box::new(LogicalExpression::Binary {
1357 left: Box::new(LogicalExpression::Property {
1358 variable: "n".to_string(),
1359 property: "city".to_string(),
1360 }),
1361 op: BinaryOp::Eq,
1362 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1363 }),
1364 },
1365 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1366 variable: "n".to_string(),
1367 label: Some("Person".to_string()),
1368 input: None,
1369 })),
1370 });
1371
1372 let cardinality = estimator.estimate(&filter);
1373 assert!(cardinality < 100.0);
1375 assert!(cardinality >= 1.0);
1376 }
1377
1378 #[test]
1379 fn test_filter_literal_true() {
1380 let mut estimator = CardinalityEstimator::new();
1381 estimator.add_table_stats("Person", TableStats::new(1000));
1382
1383 let filter = LogicalOperator::Filter(FilterOp {
1384 predicate: LogicalExpression::Literal(Value::Bool(true)),
1385 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1386 variable: "n".to_string(),
1387 label: Some("Person".to_string()),
1388 input: None,
1389 })),
1390 });
1391
1392 let cardinality = estimator.estimate(&filter);
1393 assert!((cardinality - 1000.0).abs() < 0.001);
1395 }
1396
1397 #[test]
1398 fn test_filter_literal_false() {
1399 let mut estimator = CardinalityEstimator::new();
1400 estimator.add_table_stats("Person", TableStats::new(1000));
1401
1402 let filter = LogicalOperator::Filter(FilterOp {
1403 predicate: LogicalExpression::Literal(Value::Bool(false)),
1404 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1405 variable: "n".to_string(),
1406 label: Some("Person".to_string()),
1407 input: None,
1408 })),
1409 });
1410
1411 let cardinality = estimator.estimate(&filter);
1412 assert!((cardinality - 1.0).abs() < 0.001);
1414 }
1415
#[test]
fn test_unary_not_selectivity() {
    // NOT(true) must not keep the full 1000-row input.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let negated_true = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: negated_true,
        input: Box::new(scan),
    });

    assert!(estimator.estimate(&plan) < 1000.0);
}
1437
#[test]
fn test_unary_is_null_selectivity() {
    // IS NULL is expected to be highly selective: under 10% of the 1000 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let is_null = LogicalExpression::Unary {
        op: UnaryOp::IsNull,
        operand: Box::new(LogicalExpression::Variable("x".to_string())),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: is_null,
        input: Box::new(scan),
    });

    assert!(estimator.estimate(&plan) < 100.0);
}
1459
#[test]
fn test_expand_cardinality() {
    // A single outgoing hop from 100 Person nodes is expected to fan out beyond 100 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let one_hop = ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let estimated = estimator.estimate(&LogicalOperator::Expand(one_hop));
    assert!(estimated > 100.0);
}
1486
#[test]
fn test_expand_with_edge_type_filter() {
    // Restricting the hop to KNOWS edges should still expand beyond the 100 source nodes.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows_hop = ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec![String::from("KNOWS")],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let estimated = estimator.estimate(&LogicalOperator::Expand(knows_hop));
    assert!(estimated > 100.0);
}
1513
#[test]
fn test_expand_variable_length() {
    // Expanding 1 to 3 hops from 100 Person nodes is expected to exceed 500 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let multi_hop = ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(3),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let estimated = estimator.estimate(&LogicalOperator::Expand(multi_hop));
    assert!(estimated > 500.0);
}
1540
#[test]
fn test_join_cross_product() {
    // A cross join multiplies its inputs: 100 people x 50 companies = 5000 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));
    estimator.add_table_stats("Company", TableStats::new(50));

    let scan = |var: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: scan("p", "Person"),
        right: scan("c", "Company"),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 5000.0).abs() < 0.001);
}
1566
#[test]
fn test_join_left_outer() {
    // A LEFT join keeps every left-side row, so the estimate must be at least the
    // 1000 Person rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(10));

    let scan = |var: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };
    let condition = JoinCondition {
        left: LogicalExpression::Variable("p".to_string()),
        right: LogicalExpression::Variable("c".to_string()),
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: scan("p", "Person"),
        right: scan("c", "Company"),
        join_type: JoinType::Left,
        conditions: vec![condition],
    });

    assert!(estimator.estimate(&plan) >= 1000.0);
}
1595
#[test]
fn test_join_semi() {
    // A semi-join can only filter its left input, so it never exceeds 1000 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(100));

    let scan = |var: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: scan("p", "Person"),
        right: scan("c", "Company"),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    assert!(estimator.estimate(&plan) <= 1000.0);
}
1621
#[test]
fn test_join_anti() {
    // An anti-join filters the left input. The estimate stays within
    // [1, 1000]: never above the Person count, never below the 1-row floor.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(100));

    let scan = |var: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: scan("p", "Person"),
        right: scan("c", "Company"),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let estimated = estimator.estimate(&plan);
    assert!((1.0..=1000.0).contains(&estimated));
}
1648
#[test]
fn test_project_preserves_cardinality() {
    // Projection reshapes columns but leaves the 1000-row count untouched.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let column = Projection {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalOperator::Project(ProjectOp {
        projections: vec![column],
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 1000.0).abs() < 0.001);
}
1669
#[test]
fn test_sort_preserves_cardinality() {
    // Sorting reorders rows without changing how many there are.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let ascending_n = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
    };
    let plan = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_n],
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 1000.0).abs() < 0.001);
}
1690
#[test]
fn test_distinct_reduces_cardinality() {
    // Under the default configuration, DISTINCT over all columns is expected to
    // halve the input (1000 -> 500).
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        columns: None,
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 500.0).abs() < 0.001);
}
1709
#[test]
fn test_skip_reduces_cardinality() {
    // Skipping 100 of 1000 rows leaves 900.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Skip(SkipOp {
        input: Box::new(scan),
        count: 100,
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 900.0).abs() < 0.001);
}
1727
#[test]
fn test_return_preserves_cardinality() {
    // A non-distinct RETURN emits exactly as many rows as it receives.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let item = ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![item],
        distinct: false,
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 1000.0).abs() < 0.001);
}
1749
#[test]
fn test_empty_cardinality() {
    // The Empty operator produces no rows, so its estimate is zero.
    let estimated = CardinalityEstimator::new().estimate(&LogicalOperator::Empty);
    assert!(estimated.abs() < 0.001);
}
1756
#[test]
fn test_table_stats_with_column() {
    // Column statistics attached via the builder must be retrievable by name
    // with every configured field intact.
    let age = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let stats = TableStats::new(1000).with_column("age", age);

    assert_eq!(stats.row_count, 1000);

    let col = stats.columns.get("age").expect("age column should be registered");
    assert_eq!(col.distinct_count, 50);
    assert_eq!(col.null_count, 10);

    let min = col.min_value.expect("min_value should be set by with_range");
    let max = col.max_value.expect("max_value should be set by with_range");
    assert!(min.abs() < 0.001);
    assert!((max - 100.0).abs() < 0.001);
}
1771
#[test]
fn test_estimator_default() {
    // With no registered stats, an unlabelled scan falls back to a 1000-row estimate.
    let unlabelled = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });

    let estimated = CardinalityEstimator::default().estimate(&unlabelled);
    assert!((estimated - 1000.0).abs() < 0.001);
}
1783
#[test]
fn test_set_avg_fanout() {
    // With the average fan-out overridden to 5.0, one hop from 100 nodes gives 500 rows.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));
    estimator.set_avg_fanout(5.0);

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let one_hop = ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let estimated = estimator.estimate(&LogicalOperator::Expand(one_hop));
    assert!((estimated - 500.0).abs() < 0.001);
}
1811
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    // Grouping 10_000 Person rows by one or two properties must produce fewer
    // groups than input rows, but never fewer than one group.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(10000));

    let person_scan = || {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        }))
    };
    let prop = |name: &str| LogicalExpression::Property {
        variable: "n".to_string(),
        property: name.to_string(),
    };

    let single_group = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![prop("city")],
        aggregates: Vec::new(),
        input: person_scan(),
        having: None,
    });
    let multi_group = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![prop("city"), prop("country")],
        aggregates: Vec::new(),
        input: person_scan(),
        having: None,
    });

    let single_card = estimator.estimate(&single_group);
    let multi_card = estimator.estimate(&multi_group);

    for card in [single_card, multi_card] {
        assert!(card < 10000.0);
        assert!(card >= 1.0);
    }
}
1864
#[test]
fn test_histogram_build_uniform() {
    // 100 evenly spaced values (0.0..=99.0) split into 10 equi-depth buckets,
    // so each bucket should hold about 10 rows.
    let values: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let histogram = EquiDepthHistogram::build(&values, 10);

    assert_eq!(histogram.num_buckets(), 10);
    assert_eq!(histogram.total_rows(), 100);

    // Allow one row of slack either side of the ideal depth of 10.
    for bucket in histogram.buckets() {
        assert!(
            (9..=11).contains(&bucket.frequency),
            "bucket frequency {} outside 9..=11",
            bucket.frequency
        );
    }
}
1881
#[test]
fn test_histogram_build_skewed() {
    // 80 values packed into 0..80 plus 20 outliers at 1000..1020. Equi-depth
    // bucketing should still give each of the 5 buckets about 20 rows, despite
    // the gap in the value domain.
    let values: Vec<f64> = (0_u32..80)
        .map(f64::from)
        .chain((0_u32..20).map(|i| 1000.0 + f64::from(i)))
        .collect();
    let histogram = EquiDepthHistogram::build(&values, 5);

    assert_eq!(histogram.num_buckets(), 5);
    assert_eq!(histogram.total_rows(), 100);

    // Allow two rows of slack either side of the ideal depth of 20.
    for bucket in histogram.buckets() {
        assert!(
            (18..=22).contains(&bucket.frequency),
            "bucket frequency {} outside 18..=22",
            bucket.frequency
        );
    }
}
1897
#[test]
fn test_histogram_range_selectivity_full() {
    // A range unbounded on both sides covers every row.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&samples, 10).range_selectivity(None, None);

    assert!((selectivity - 1.0).abs() < 0.01);
}
1907
#[test]
fn test_histogram_range_selectivity_half() {
    // Values >= 50 make up about half of the uniform 0..100 sample.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&samples, 10).range_selectivity(Some(50.0), None);

    assert!(selectivity > 0.4 && selectivity < 0.6);
}
1917
#[test]
fn test_histogram_range_selectivity_quarter() {
    // Values below 25 make up about a quarter of the uniform 0..100 sample.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&samples, 10).range_selectivity(None, Some(25.0));

    assert!(selectivity > 0.2 && selectivity < 0.3);
}
1927
#[test]
fn test_histogram_equality_selectivity() {
    // One value out of 100 distinct values should select roughly 1% of rows.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&samples, 10).equality_selectivity(50.0);

    assert!(selectivity > 0.005 && selectivity < 0.02);
}
1937
#[test]
fn test_histogram_empty() {
    // An empty histogram has no buckets or rows. A range query against it is
    // expected to return the fixed fallback of ~0.33.
    let empty = EquiDepthHistogram::build(&[], 10);

    assert_eq!(empty.num_buckets(), 0);
    assert_eq!(empty.total_rows(), 0);

    let fallback = empty.range_selectivity(Some(0.0), Some(100.0));
    assert!((fallback - 0.33).abs() < 0.01);
}
1949
#[test]
fn test_histogram_bucket_overlap() {
    // Bucket spanning [10, 20): full overlap, each half, and both disjoint sides.
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
    let frac = |lo: f64, hi: f64| bucket.overlap_fraction(Some(lo), Some(hi));

    assert!((frac(10.0, 20.0) - 1.0).abs() < 0.01);
    assert!((frac(10.0, 15.0) - 0.5).abs() < 0.01);
    assert!((frac(15.0, 20.0) - 0.5).abs() < 0.01);
    assert!(frac(0.0, 5.0).abs() < 0.01);
    assert!(frac(25.0, 30.0).abs() < 0.01);
}
1969
#[test]
fn test_column_stats_from_values() {
    // Eight samples with 20, 30 and 40 duplicated give five distinct values,
    // spanning 10.0..=50.0.
    let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let stats = ColumnStats::from_values(values, 4);

    assert_eq!(stats.distinct_count, 5);

    let min = stats
        .min_value
        .expect("min_value should be populated from non-empty input");
    assert!((min - 10.0).abs() < 0.01);

    let max = stats
        .max_value
        .expect("max_value should be populated from non-empty input");
    assert!((max - 50.0).abs() < 0.01);

    assert!(stats.histogram.is_some());
}
1982
#[test]
fn test_column_stats_with_histogram_builder() {
    // A histogram attached via `with_histogram` is stored and keeps its bucket count.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&samples, 10));

    assert!(stats.histogram.is_some());
    let buckets = stats.histogram.as_ref().unwrap().num_buckets();
    assert_eq!(buckets, 10);
}
1995
#[test]
fn test_filter_with_histogram_stats() {
    // Ages 18..=79 (62 values). `age > 50` covers 29 of them (about 47%), so with
    // 1000 rows the histogram-based estimate should land between 300 and 600.
    let ages: Vec<f64> = (18_u32..80).map(f64::from).collect();
    let age_stats = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(EquiDepthHistogram::build(&ages, 10));

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000).with_column("age", age_stats));

    let older_than_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: older_than_50,
        input: Box::new(scan),
    });

    let cardinality = estimator.estimate(&plan);
    assert!(cardinality > 300.0 && cardinality < 600.0);
}
2036
#[test]
fn test_filter_equality_with_histogram() {
    // Equality on one of 100 distinct values over 1000 rows should give a small
    // estimate, at least 1 and below 50.
    let samples: Vec<f64> = (0_u32..100).map(f64::from).collect();
    let value_stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&samples, 10));

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", value_stats));

    let equals_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "d".to_string(),
            property: "value".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "d".to_string(),
        label: Some("Data".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: equals_50,
        input: Box::new(scan),
    });

    let cardinality = estimator.estimate(&plan);
    assert!((1.0..50.0).contains(&cardinality));
}
2073
#[test]
fn test_histogram_min_max() {
    // The histogram's lower edge is the smallest input, and an upper edge exists.
    let histogram = EquiDepthHistogram::build(&[5.0, 10.0, 15.0, 20.0, 25.0], 2);

    assert_eq!(histogram.min_value(), Some(5.0));
    assert!(histogram.max_value().is_some());
}
2083
#[test]
fn test_selectivity_config_defaults() {
    // Pin every default selectivity so accidental changes show up here.
    let config = SelectivityConfig::new();
    let close = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    assert!(close(config.default, 0.1));
    assert!(close(config.equality, 0.01));
    assert!(close(config.inequality, 0.99));
    assert!(close(config.range, 0.33));
    assert!(close(config.string_ops, 0.1));
    assert!(close(config.membership, 0.1));
    assert!(close(config.is_null, 0.05));
    assert!(close(config.is_not_null, 0.95));
    assert!(close(config.distinct_fraction, 0.5));
}
2099
#[test]
fn test_custom_selectivity_config() {
    // Overridden fields must be exposed unchanged through `selectivity_config()`.
    let estimator = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        equality: 0.05,
        range: 0.25,
        ..SelectivityConfig::new()
    });
    let active = estimator.selectivity_config();

    assert!((active.equality - 0.05).abs() < f64::EPSILON);
    assert!((active.range - 0.25).abs() < f64::EPSILON);
}
2111
#[test]
fn test_custom_selectivity_affects_estimation() {
    // Raising the equality selectivity to 0.2 should raise a stats-less equality
    // filter over 1000 rows above the default estimate, to about 200.
    let name_is_alice = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Alice".into()))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: name_is_alice,
        input: Box::new(scan),
    });

    let mut default_est = CardinalityEstimator::new();
    default_est.add_table_stats("Person", TableStats::new(1000));
    let default_card = default_est.estimate(&filter);

    let mut custom_est = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        equality: 0.2,
        ..SelectivityConfig::new()
    });
    custom_est.add_table_stats("Person", TableStats::new(1000));
    let custom_card = custom_est.estimate(&filter);

    assert!(custom_card > default_card);
    assert!((custom_card - 200.0).abs() < 1.0);
}
2149
#[test]
fn test_custom_range_selectivity() {
    // With the range selectivity set to 0.5, a stats-less `age > 30` filter over
    // 1000 rows should estimate about 500.
    let mut estimator = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        range: 0.5,
        ..SelectivityConfig::new()
    });
    estimator.add_table_stats("Person", TableStats::new(1000));

    let older_than_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: older_than_30,
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 500.0).abs() < 1.0);
}
2179
#[test]
fn test_custom_distinct_fraction() {
    // With distinct_fraction set to 0.8, DISTINCT over 1000 rows should give about 800.
    let mut estimator = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        distinct_fraction: 0.8,
        ..SelectivityConfig::new()
    });
    estimator.add_table_stats("Person", TableStats::new(1000));

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        columns: None,
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&plan);
    assert!((estimated - 800.0).abs() < 1.0);
}
2202
#[test]
fn test_estimation_log_basic() {
    // Both estimates are close to their actual counts: 1.2x under and ~1.1x over,
    // well inside the 10.0 passed to `EstimationLog::new`. Both entries are
    // recorded and no re-plan is requested.
    let mut log = EstimationLog::new(10.0);
    log.record("NodeScan(Person)", 1000.0, 1200.0);
    log.record("Filter(age > 30)", 100.0, 90.0);

    assert_eq!(log.entries().len(), 2);
    assert!(!log.should_replan());
}
2214
#[test]
fn test_estimation_log_triggers_replan() {
    // Estimating 100 rows when 5000 arrive is a 50x underestimate, beyond the
    // 10.0 threshold, so a re-plan should be requested.
    let mut log = EstimationLog::new(10.0);
    log.record("NodeScan(Person)", 100.0, 5000.0);

    assert!(log.should_replan());
}
2222
#[test]
fn test_estimation_log_overestimate_triggers_replan() {
    // Overestimates count too: 1000 estimated vs 100 actual is a 10x error,
    // beyond the 5.0 threshold.
    let mut log = EstimationLog::new(5.0);
    log.record("Filter", 1000.0, 100.0);

    assert!(log.should_replan());
}
2230
#[test]
fn test_estimation_entry_error_ratio() {
    // error_ratio is 2.0 for a 2x miss and 1.0 for an exact hit. The
    // degenerate 0-estimated/0-actual case also reports 1.0 rather than NaN.
    let entry = |estimated: f64, actual: f64| EstimationEntry {
        operator: "test".into(),
        estimated,
        actual,
    };

    let doubled = entry(100.0, 200.0);
    assert!((doubled.error_ratio() - 2.0).abs() < f64::EPSILON);

    let perfect = entry(100.0, 100.0);
    assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);

    let zero_est = entry(0.0, 0.0);
    assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
}
2254
#[test]
fn test_estimation_log_max_error_ratio() {
    // max_error_ratio reports the worst entry in the log.
    let mut log = EstimationLog::new(10.0);
    log.record("A", 100.0, 300.0); // 3x underestimate (the worst)
    log.record("B", 100.0, 50.0); // 2x overestimate
    log.record("C", 100.0, 100.0); // exact

    assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
}
2264
#[test]
fn test_estimation_log_clear() {
    // clear() drops every recorded entry and resets the re-plan signal.
    let mut log = EstimationLog::new(10.0);
    log.record("A", 100.0, 100.0);
    assert_eq!(log.entries().len(), 1);

    log.clear();

    assert!(log.entries().is_empty());
    assert!(!log.should_replan());
}
2275
#[test]
fn test_create_estimation_log() {
    // A freshly created log starts empty and does not ask for a re-plan.
    let fresh = CardinalityEstimator::create_estimation_log();

    assert!(fresh.entries().is_empty());
    assert!(!fresh.should_replan());
}
2282}