1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19 LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20 VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// A single bucket of an equi-depth histogram over a numeric column.
///
/// A bucket normally covers the half-open interval
/// `[lower_bound, upper_bound)`. When both bounds coincide the bucket is a
/// *point bucket*; [`EquiDepthHistogram::build`] produces these when one value
/// fills an entire bucket and continues into the next one.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket.
    pub upper_bound: f64,
    /// Number of rows that fall into the bucket.
    pub frequency: u64,
    /// Number of distinct values among the rows of the bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and row/distinct counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` when `value` lies in `[lower_bound, upper_bound)`.
    ///
    /// Point buckets (zero width) therefore never contain any value.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        value >= self.lower_bound && value < self.upper_bound
    }

    /// Returns the fraction (`0.0..=1.0`) of the bucket covered by the range
    /// `lower..upper`, where `None` means "unbounded on that side".
    ///
    /// A bucket with non-positive width counts as fully covered when the range
    /// reaches both of its bounds and as not covered otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the requested range to the bucket.
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            // No width to interpolate over: all-or-nothing.
            return if effective_lower <= self.lower_bound && effective_upper >= self.upper_bound {
                1.0
            } else {
                0.0
            };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }

    /// Returns `true` for a bucket with non-positive width.
    fn is_point(&self) -> bool {
        self.width() <= 0.0
    }
}

/// Equi-depth histogram: buckets hold roughly the same number of rows.
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// Buckets in ascending value order.
    buckets: Vec<HistogramBucket>,
    /// Sum of all bucket frequencies.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Wraps pre-built buckets; the total row count is the sum of their
    /// frequencies.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds a histogram with at most `num_buckets` buckets.
    ///
    /// `values` must already be sorted in ascending order; this is not
    /// checked. An empty slice or `num_buckets == 0` yields an empty
    /// histogram.
    ///
    /// Each bucket's upper bound is the first value of the following bucket;
    /// the last bucket ends at `last value + 1.0`.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self {
                buckets: Vec::new(),
                total_rows: 0,
            };
        }

        let num_buckets = num_buckets.min(values.len());
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        // `consumed` is the index of the first value after the current chunk.
        let mut consumed = 0;
        for chunk in values.chunks(rows_per_bucket) {
            consumed += chunk.len();
            let lower_bound = chunk[0];
            let upper_bound = values
                .get(consumed)
                .copied()
                .unwrap_or(chunk[chunk.len() - 1] + 1.0);

            buckets.push(HistogramBucket::new(
                lower_bound,
                upper_bound,
                chunk.len() as u64,
                count_distinct(chunk),
            ));
        }

        Self::new(buckets)
    }

    /// Number of buckets in the histogram.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Total number of rows covered by all buckets.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// The buckets, in the order they were supplied or built.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimates the fraction of rows with `lower <= value < upper`.
    ///
    /// `None` leaves that side unbounded. Returns `0.33` for an empty
    /// histogram. Rows inside a partially covered bucket are assumed to be
    /// spread uniformly over its width.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33;
        }

        let mut matching_rows = 0.0;

        for bucket in &self.buckets {
            if let Some(l) = lower {
                // A regular bucket excludes its upper bound, so touching `l`
                // from below is no overlap. A point bucket sits exactly on
                // its bound and must survive an inclusive lower bound equal
                // to it; otherwise duplicated values are dropped.
                let entirely_below = if bucket.is_point() {
                    bucket.upper_bound < l
                } else {
                    bucket.upper_bound <= l
                };
                if entirely_below {
                    continue;
                }
            }
            if upper.is_some_and(|u| bucket.lower_bound >= u) {
                continue;
            }

            let overlap = bucket.overlap_fraction(lower, upper);
            matching_rows += overlap * bucket.frequency as f64;
        }

        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
    }

    /// Estimates the fraction of rows equal to `value`.
    ///
    /// Returns `0.01` for an empty histogram and `0.001` when no bucket
    /// matches. Within a regular bucket the rows are assumed to be spread
    /// evenly over its distinct values. Point buckets located exactly at
    /// `value` contribute all of their rows, so a value that spans several
    /// buckets is not under-counted.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01;
        }

        // Exact comparison is intended: point-bucket bounds are copies of
        // column values, not computed quantities.
        let point_rows: f64 = self
            .buckets
            .iter()
            .filter(|b| b.is_point() && b.lower_bound == value)
            .map(|b| b.frequency as f64)
            .sum();

        // First regular bucket whose half-open range holds the value.
        let spread_rows = self
            .buckets
            .iter()
            .find(|b| b.distinct_count > 0 && b.contains(value))
            .map(|b| b.frequency as f64 / b.distinct_count as f64);

        match (spread_rows, point_rows > 0.0) {
            (None, false) => 0.001,
            (spread, _) => {
                ((point_rows + spread.unwrap_or(0.0)) / self.total_rows as f64).min(1.0)
            }
        }
    }

    /// Lower bound of the first bucket, if any.
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Upper bound of the last bucket, if any.
    ///
    /// For histograms created by [`EquiDepthHistogram::build`] this is the
    /// largest value plus `1.0`.
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts distinct values in an ascending slice.
///
/// A value counts as new when it differs from the last counted value by more
/// than `f64::EPSILON` (absolute difference).
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    let mut count = 1u64;
    let mut prev = sorted_values[0];

    for &val in &sorted_values[1..] {
        if (val - prev).abs() > f64::EPSILON {
            count += 1;
            prev = val;
        }
    }

    count
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
/// Fallback selectivities used when no column statistics are available.
///
/// Each field is the assumed fraction of rows that pass a predicate of the
/// given kind.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Predicates with no more specific category.
    pub default: f64,
    /// Equality comparisons.
    pub equality: f64,
    /// Not-equal comparisons.
    pub inequality: f64,
    /// Range comparisons (`<`, `<=`, `>`, `>=`).
    pub range: f64,
    /// String operators (prefix, suffix, substring, pattern).
    pub string_ops: f64,
    /// Membership tests.
    pub membership: f64,
    /// `IS NULL` checks.
    pub is_null: f64,
    /// `IS NOT NULL` checks.
    pub is_not_null: f64,
    /// Fraction of rows assumed to survive a distinct operator.
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns the built-in defaults; identical to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl Default for SelectivityConfig {
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}
446
/// One recorded comparison of an estimated and an observed cardinality.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Name of the operator the estimate belongs to.
    pub operator: String,
    /// Cardinality predicted by the estimator.
    pub estimated: f64,
    /// Cardinality observed at execution time.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// When the estimate is (almost) zero the ratio is `1.0` if the actual
    /// value is (almost) zero as well, and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() < f64::EPSILON {
            if self.actual.abs() < f64::EPSILON {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            self.actual / self.estimated
        }
    }
}

/// Collects estimation errors and decides whether they justify re-planning.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries, in insertion order.
    entries: Vec<EstimationEntry>,
    /// Error factor beyond which an entry triggers a re-plan: an entry is off
    /// when its ratio exceeds this value or falls below its reciprocal.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// Creates an empty log with a replan threshold of `10.0`.
    ///
    /// A derived `Default` would use `0.0`, whose reciprocal is infinite and
    /// makes `should_replan` report `true` for every recorded entry.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REPLAN_THRESHOLD)
    }
}

impl EstimationLog {
    /// Threshold used by `Default`; the same value
    /// `CardinalityEstimator::create_estimation_log` passes to `new`.
    const DEFAULT_REPLAN_THRESHOLD: f64 = 10.0;

    /// Creates an empty log with the given replan threshold.
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Appends an estimated/actual pair for `operator`.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// All recorded entries, in insertion order.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns `true` when any entry's error ratio is above the threshold or
    /// below its reciprocal.
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Returns the largest symmetric error factor over all entries: ratios
    /// below `1.0` are inverted first. `1.0` for an empty log.
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Removes all recorded entries; the threshold is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
547
/// Estimates how many rows each logical plan operator produces.
pub struct CardinalityEstimator {
    /// Table statistics keyed by label name.
    table_stats: HashMap<String, TableStats>,
    /// Row count assumed when no statistics match a scan.
    default_row_count: u64,
    /// Selectivity used for predicates without a more specific estimate.
    default_selectivity: f64,
    /// Average number of result rows per input row of an expand.
    avg_fanout: f64,
    /// Fallback selectivities per predicate kind.
    selectivity_config: SelectivityConfig,
}
561
562impl CardinalityEstimator {
563 #[must_use]
565 pub fn new() -> Self {
566 let config = SelectivityConfig::new();
567 Self {
568 table_stats: HashMap::new(),
569 default_row_count: 1000,
570 default_selectivity: config.default,
571 avg_fanout: 10.0,
572 selectivity_config: config,
573 }
574 }
575
576 #[must_use]
578 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579 Self {
580 table_stats: HashMap::new(),
581 default_row_count: 1000,
582 default_selectivity: config.default,
583 avg_fanout: 10.0,
584 selectivity_config: config,
585 }
586 }
587
    /// Returns the fallback selectivity configuration in use.
    #[must_use]
    pub fn selectivity_config(&self) -> &SelectivityConfig {
        &self.selectivity_config
    }
593
    /// Creates an empty [`EstimationLog`] with a replan threshold of `10.0`.
    #[must_use]
    pub fn create_estimation_log() -> EstimationLog {
        EstimationLog::new(10.0)
    }
599
600 #[must_use]
606 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607 let mut estimator = Self::new();
608
609 if stats.total_nodes > 0 {
611 estimator.default_row_count = stats.total_nodes;
612 }
613
614 for (label, label_stats) in &stats.labels {
616 let mut table_stats = TableStats::new(label_stats.node_count);
617
618 for (prop, col_stats) in &label_stats.properties {
620 let optimizer_col =
621 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622 table_stats = table_stats.with_column(prop, optimizer_col);
623 }
624
625 estimator.add_table_stats(label, table_stats);
626 }
627
628 if !stats.edge_types.is_empty() {
630 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632 } else if stats.total_nodes > 0 {
633 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634 }
635
636 if estimator.avg_fanout < 1.0 {
638 estimator.avg_fanout = 1.0;
639 }
640
641 estimator
642 }
643
    /// Registers `stats` for the table (label) `name`, replacing any
    /// statistics previously stored under that name.
    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
        self.table_stats.insert(name.to_string(), stats);
    }
648
    /// Overrides the average fanout used for expand estimates.
    ///
    /// The value is stored as given; unlike `from_statistics`, no lower
    /// bound of `1.0` is applied here.
    pub fn set_avg_fanout(&mut self, fanout: f64) {
        self.avg_fanout = fanout;
    }
653
    /// Estimates the number of rows produced by `op`.
    ///
    /// Dispatches on the operator kind. `Return` forwards to its input,
    /// `Empty` yields `0.0`, and any operator without a dedicated estimate
    /// falls back to the default row count.
    #[must_use]
    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
        match op {
            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
            LogicalOperator::Project(project) => self.estimate_project(project),
            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
            LogicalOperator::Join(join) => self.estimate_join(join),
            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
            LogicalOperator::Return(ret) => self.estimate(&ret.input),
            LogicalOperator::Empty => 0.0,
            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
            // Operators without a dedicated model.
            _ => self.default_row_count as f64,
        }
    }
675
676 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
678 if let Some(label) = &scan.label
679 && let Some(stats) = self.table_stats.get(label)
680 {
681 return stats.row_count as f64;
682 }
683 self.default_row_count as f64
685 }
686
687 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
689 let input_cardinality = self.estimate(&filter.input);
690 let selectivity = self.estimate_selectivity(&filter.predicate);
691 (input_cardinality * selectivity).max(1.0)
692 }
693
    /// A projection does not change the row count: returns the input's
    /// estimate.
    fn estimate_project(&self, project: &ProjectOp) -> f64 {
        self.estimate(&project.input)
    }
698
699 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
701 let input_cardinality = self.estimate(&expand.input);
702
703 let fanout = if expand.edge_type.is_some() {
705 self.avg_fanout * 0.5
707 } else {
708 self.avg_fanout
709 };
710
711 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
713 let min = expand.min_hops as f64;
714 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
715 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
717 } else {
718 fanout
719 };
720
721 (input_cardinality * path_multiplier).max(1.0)
722 }
723
724 fn estimate_join(&self, join: &JoinOp) -> f64 {
726 let left_card = self.estimate(&join.left);
727 let right_card = self.estimate(&join.right);
728
729 match join.join_type {
730 JoinType::Cross => left_card * right_card,
731 JoinType::Inner => {
732 let selectivity = if join.conditions.is_empty() {
734 1.0 } else {
736 0.1_f64.powi(join.conditions.len() as i32)
738 };
739 (left_card * right_card * selectivity).max(1.0)
740 }
741 JoinType::Left => {
742 let inner_card = self.estimate_join(&JoinOp {
744 left: join.left.clone(),
745 right: join.right.clone(),
746 join_type: JoinType::Inner,
747 conditions: join.conditions.clone(),
748 });
749 inner_card.max(left_card)
750 }
751 JoinType::Right => {
752 let inner_card = self.estimate_join(&JoinOp {
754 left: join.left.clone(),
755 right: join.right.clone(),
756 join_type: JoinType::Inner,
757 conditions: join.conditions.clone(),
758 });
759 inner_card.max(right_card)
760 }
761 JoinType::Full => {
762 let inner_card = self.estimate_join(&JoinOp {
764 left: join.left.clone(),
765 right: join.right.clone(),
766 join_type: JoinType::Inner,
767 conditions: join.conditions.clone(),
768 });
769 inner_card.max(left_card.max(right_card))
770 }
771 JoinType::Semi => {
772 (left_card * self.default_selectivity).max(1.0)
774 }
775 JoinType::Anti => {
776 (left_card * (1.0 - self.default_selectivity)).max(1.0)
778 }
779 }
780 }
781
782 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
784 let input_cardinality = self.estimate(&agg.input);
785
786 if agg.group_by.is_empty() {
787 1.0
789 } else {
790 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
793 (input_cardinality / group_reduction).max(1.0)
794 }
795 }
796
    /// Sorting does not change the row count: returns the input's estimate.
    fn estimate_sort(&self, sort: &SortOp) -> f64 {
        self.estimate(&sort.input)
    }
801
802 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
804 let input_cardinality = self.estimate(&distinct.input);
805 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
806 }
807
808 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
810 let input_cardinality = self.estimate(&limit.input);
811 (limit.count as f64).min(input_cardinality)
812 }
813
814 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
816 let input_cardinality = self.estimate(&skip.input);
817 (input_cardinality - skip.count as f64).max(0.0)
818 }
819
820 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
825 let base_k = scan.k as f64;
826
827 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
829 0.7
831 } else {
832 1.0
833 };
834
835 (base_k * selectivity).max(1.0)
836 }
837
838 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
842 let input_cardinality = self.estimate(&join.input);
843 let k = join.k as f64;
844
845 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
847 0.7
848 } else {
849 1.0
850 };
851
852 (input_cardinality * k * selectivity).max(1.0)
853 }
854
855 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
857 match expr {
858 LogicalExpression::Binary { left, op, right } => {
859 self.estimate_binary_selectivity(left, *op, right)
860 }
861 LogicalExpression::Unary { op, operand } => {
862 self.estimate_unary_selectivity(*op, operand)
863 }
864 LogicalExpression::Literal(value) => {
865 if let grafeo_common::types::Value::Bool(b) = value {
867 if *b { 1.0 } else { 0.0 }
868 } else {
869 self.default_selectivity
870 }
871 }
872 _ => self.default_selectivity,
873 }
874 }
875
876 fn estimate_binary_selectivity(
878 &self,
879 left: &LogicalExpression,
880 op: BinaryOp,
881 right: &LogicalExpression,
882 ) -> f64 {
883 match op {
884 BinaryOp::Eq => {
886 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
887 return selectivity;
888 }
889 self.selectivity_config.equality
890 }
891 BinaryOp::Ne => self.selectivity_config.inequality,
893 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
895 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
896 return selectivity;
897 }
898 self.selectivity_config.range
899 }
900 BinaryOp::And => {
902 let left_sel = self.estimate_selectivity(left);
903 let right_sel = self.estimate_selectivity(right);
904 left_sel * right_sel
906 }
907 BinaryOp::Or => {
908 let left_sel = self.estimate_selectivity(left);
909 let right_sel = self.estimate_selectivity(right);
910 (left_sel + right_sel - left_sel * right_sel).min(1.0)
913 }
914 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
916 self.selectivity_config.string_ops
917 }
918 BinaryOp::In => self.selectivity_config.membership,
920 _ => self.default_selectivity,
922 }
923 }
924
925 fn try_equality_selectivity(
927 &self,
928 left: &LogicalExpression,
929 right: &LogicalExpression,
930 ) -> Option<f64> {
931 let (label, column, value) = self.extract_column_and_value(left, right)?;
933
934 let stats = self.get_column_stats(&label, &column)?;
936
937 if let Some(ref histogram) = stats.histogram {
939 return Some(histogram.equality_selectivity(value));
940 }
941
942 if stats.distinct_count > 0 {
944 return Some(1.0 / stats.distinct_count as f64);
945 }
946
947 None
948 }
949
950 fn try_range_selectivity(
952 &self,
953 left: &LogicalExpression,
954 op: BinaryOp,
955 right: &LogicalExpression,
956 ) -> Option<f64> {
957 let (label, column, value) = self.extract_column_and_value(left, right)?;
959
960 let stats = self.get_column_stats(&label, &column)?;
962
963 let (lower, upper) = match op {
965 BinaryOp::Lt => (None, Some(value)),
966 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
967 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
968 BinaryOp::Ge => (Some(value), None),
969 _ => return None,
970 };
971
972 if let Some(ref histogram) = stats.histogram {
974 return Some(histogram.range_selectivity(lower, upper));
975 }
976
977 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
979 let range = max - min;
980 if range <= 0.0 {
981 return Some(1.0);
982 }
983
984 let effective_lower = lower.unwrap_or(min).max(min);
985 let effective_upper = upper.unwrap_or(max).min(max);
986 let overlap = (effective_upper - effective_lower).max(0.0);
987 return Some((overlap / range).min(1.0).max(0.0));
988 }
989
990 None
991 }
992
993 fn extract_column_and_value(
998 &self,
999 left: &LogicalExpression,
1000 right: &LogicalExpression,
1001 ) -> Option<(String, String, f64)> {
1002 if let Some(result) = self.try_extract_property_literal(left, right) {
1004 return Some(result);
1005 }
1006
1007 self.try_extract_property_literal(right, left)
1009 }
1010
1011 fn try_extract_property_literal(
1013 &self,
1014 property_expr: &LogicalExpression,
1015 literal_expr: &LogicalExpression,
1016 ) -> Option<(String, String, f64)> {
1017 let (variable, property) = match property_expr {
1019 LogicalExpression::Property { variable, property } => {
1020 (variable.clone(), property.clone())
1021 }
1022 _ => return None,
1023 };
1024
1025 let value = match literal_expr {
1027 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1028 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1029 _ => return None,
1030 };
1031
1032 for label in self.table_stats.keys() {
1036 if let Some(stats) = self.table_stats.get(label)
1037 && stats.columns.contains_key(&property)
1038 {
1039 return Some((label.clone(), property, value));
1040 }
1041 }
1042
1043 Some((variable, property, value))
1045 }
1046
1047 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1049 match op {
1050 UnaryOp::Not => 1.0 - self.default_selectivity,
1051 UnaryOp::IsNull => self.selectivity_config.is_null,
1052 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1053 UnaryOp::Neg => 1.0, }
1055 }
1056
1057 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1059 self.table_stats.get(label)?.columns.get(column)
1060 }
1061
1062 #[allow(dead_code)]
1064 fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
1065 if let Some(stats) = self.get_column_stats(label, column)
1066 && stats.distinct_count > 0
1067 {
1068 return 1.0 / stats.distinct_count as f64;
1069 }
1070 self.selectivity_config.equality
1071 }
1072
1073 #[allow(dead_code)]
1075 fn estimate_range_with_stats(
1076 &self,
1077 label: &str,
1078 column: &str,
1079 lower: Option<f64>,
1080 upper: Option<f64>,
1081 ) -> f64 {
1082 if let Some(stats) = self.get_column_stats(label, column)
1083 && let (Some(min), Some(max)) = (stats.min_value, stats.max_value)
1084 {
1085 let range = max - min;
1086 if range <= 0.0 {
1087 return 1.0;
1088 }
1089
1090 let effective_lower = lower.unwrap_or(min).max(min);
1091 let effective_upper = upper.unwrap_or(max).min(max);
1092
1093 let overlap = (effective_upper - effective_lower).max(0.0);
1094 return (overlap / range).min(1.0).max(0.0);
1095 }
1096 self.selectivity_config.range
1097 }
1098}
1099
/// The default estimator is the one returned by `CardinalityEstimator::new`.
impl Default for CardinalityEstimator {
    fn default() -> Self {
        Self::new()
    }
}
1105
1106#[cfg(test)]
1107mod tests {
1108 use super::*;
1109 use crate::query::plan::{
1110 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
1111 Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1112 };
1113 use grafeo_common::types::Value;
1114
1115 #[test]
1116 fn test_node_scan_with_stats() {
1117 let mut estimator = CardinalityEstimator::new();
1118 estimator.add_table_stats("Person", TableStats::new(5000));
1119
1120 let scan = LogicalOperator::NodeScan(NodeScanOp {
1121 variable: "n".to_string(),
1122 label: Some("Person".to_string()),
1123 input: None,
1124 });
1125
1126 let cardinality = estimator.estimate(&scan);
1127 assert!((cardinality - 5000.0).abs() < 0.001);
1128 }
1129
1130 #[test]
1131 fn test_filter_reduces_cardinality() {
1132 let mut estimator = CardinalityEstimator::new();
1133 estimator.add_table_stats("Person", TableStats::new(1000));
1134
1135 let filter = LogicalOperator::Filter(FilterOp {
1136 predicate: LogicalExpression::Binary {
1137 left: Box::new(LogicalExpression::Property {
1138 variable: "n".to_string(),
1139 property: "age".to_string(),
1140 }),
1141 op: BinaryOp::Eq,
1142 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1143 },
1144 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1145 variable: "n".to_string(),
1146 label: Some("Person".to_string()),
1147 input: None,
1148 })),
1149 });
1150
1151 let cardinality = estimator.estimate(&filter);
1152 assert!(cardinality < 1000.0);
1154 assert!(cardinality >= 1.0);
1155 }
1156
1157 #[test]
1158 fn test_join_cardinality() {
1159 let mut estimator = CardinalityEstimator::new();
1160 estimator.add_table_stats("Person", TableStats::new(1000));
1161 estimator.add_table_stats("Company", TableStats::new(100));
1162
1163 let join = LogicalOperator::Join(JoinOp {
1164 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1165 variable: "p".to_string(),
1166 label: Some("Person".to_string()),
1167 input: None,
1168 })),
1169 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1170 variable: "c".to_string(),
1171 label: Some("Company".to_string()),
1172 input: None,
1173 })),
1174 join_type: JoinType::Inner,
1175 conditions: vec![JoinCondition {
1176 left: LogicalExpression::Property {
1177 variable: "p".to_string(),
1178 property: "company_id".to_string(),
1179 },
1180 right: LogicalExpression::Property {
1181 variable: "c".to_string(),
1182 property: "id".to_string(),
1183 },
1184 }],
1185 });
1186
1187 let cardinality = estimator.estimate(&join);
1188 assert!(cardinality < 1000.0 * 100.0);
1190 }
1191
1192 #[test]
1193 fn test_limit_caps_cardinality() {
1194 let mut estimator = CardinalityEstimator::new();
1195 estimator.add_table_stats("Person", TableStats::new(1000));
1196
1197 let limit = LogicalOperator::Limit(LimitOp {
1198 count: 10,
1199 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1200 variable: "n".to_string(),
1201 label: Some("Person".to_string()),
1202 input: None,
1203 })),
1204 });
1205
1206 let cardinality = estimator.estimate(&limit);
1207 assert!((cardinality - 10.0).abs() < 0.001);
1208 }
1209
1210 #[test]
1211 fn test_aggregate_reduces_cardinality() {
1212 let mut estimator = CardinalityEstimator::new();
1213 estimator.add_table_stats("Person", TableStats::new(1000));
1214
1215 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1217 group_by: vec![],
1218 aggregates: vec![],
1219 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1220 variable: "n".to_string(),
1221 label: Some("Person".to_string()),
1222 input: None,
1223 })),
1224 having: None,
1225 });
1226
1227 let cardinality = estimator.estimate(&global_agg);
1228 assert!((cardinality - 1.0).abs() < 0.001);
1229
1230 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1232 group_by: vec![LogicalExpression::Property {
1233 variable: "n".to_string(),
1234 property: "city".to_string(),
1235 }],
1236 aggregates: vec![],
1237 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1238 variable: "n".to_string(),
1239 label: Some("Person".to_string()),
1240 input: None,
1241 })),
1242 having: None,
1243 });
1244
1245 let cardinality = estimator.estimate(&group_agg);
1246 assert!(cardinality < 1000.0);
1248 }
1249
1250 #[test]
1251 fn test_node_scan_without_stats() {
1252 let estimator = CardinalityEstimator::new();
1253
1254 let scan = LogicalOperator::NodeScan(NodeScanOp {
1255 variable: "n".to_string(),
1256 label: Some("Unknown".to_string()),
1257 input: None,
1258 });
1259
1260 let cardinality = estimator.estimate(&scan);
1261 assert!((cardinality - 1000.0).abs() < 0.001);
1263 }
1264
1265 #[test]
1266 fn test_node_scan_no_label() {
1267 let estimator = CardinalityEstimator::new();
1268
1269 let scan = LogicalOperator::NodeScan(NodeScanOp {
1270 variable: "n".to_string(),
1271 label: None,
1272 input: None,
1273 });
1274
1275 let cardinality = estimator.estimate(&scan);
1276 assert!((cardinality - 1000.0).abs() < 0.001);
1278 }
1279
1280 #[test]
1281 fn test_filter_inequality_selectivity() {
1282 let mut estimator = CardinalityEstimator::new();
1283 estimator.add_table_stats("Person", TableStats::new(1000));
1284
1285 let filter = LogicalOperator::Filter(FilterOp {
1286 predicate: LogicalExpression::Binary {
1287 left: Box::new(LogicalExpression::Property {
1288 variable: "n".to_string(),
1289 property: "age".to_string(),
1290 }),
1291 op: BinaryOp::Ne,
1292 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1293 },
1294 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1295 variable: "n".to_string(),
1296 label: Some("Person".to_string()),
1297 input: None,
1298 })),
1299 });
1300
1301 let cardinality = estimator.estimate(&filter);
1302 assert!(cardinality > 900.0);
1304 }
1305
1306 #[test]
1307 fn test_filter_range_selectivity() {
1308 let mut estimator = CardinalityEstimator::new();
1309 estimator.add_table_stats("Person", TableStats::new(1000));
1310
1311 let filter = LogicalOperator::Filter(FilterOp {
1312 predicate: LogicalExpression::Binary {
1313 left: Box::new(LogicalExpression::Property {
1314 variable: "n".to_string(),
1315 property: "age".to_string(),
1316 }),
1317 op: BinaryOp::Gt,
1318 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1319 },
1320 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1321 variable: "n".to_string(),
1322 label: Some("Person".to_string()),
1323 input: None,
1324 })),
1325 });
1326
1327 let cardinality = estimator.estimate(&filter);
1328 assert!(cardinality < 500.0);
1330 assert!(cardinality > 100.0);
1331 }
1332
1333 #[test]
1334 fn test_filter_and_selectivity() {
1335 let mut estimator = CardinalityEstimator::new();
1336 estimator.add_table_stats("Person", TableStats::new(1000));
1337
1338 let filter = LogicalOperator::Filter(FilterOp {
1341 predicate: LogicalExpression::Binary {
1342 left: Box::new(LogicalExpression::Binary {
1343 left: Box::new(LogicalExpression::Property {
1344 variable: "n".to_string(),
1345 property: "city".to_string(),
1346 }),
1347 op: BinaryOp::Eq,
1348 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1349 }),
1350 op: BinaryOp::And,
1351 right: Box::new(LogicalExpression::Binary {
1352 left: Box::new(LogicalExpression::Property {
1353 variable: "n".to_string(),
1354 property: "age".to_string(),
1355 }),
1356 op: BinaryOp::Eq,
1357 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1358 }),
1359 },
1360 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1361 variable: "n".to_string(),
1362 label: Some("Person".to_string()),
1363 input: None,
1364 })),
1365 });
1366
1367 let cardinality = estimator.estimate(&filter);
1368 assert!(cardinality < 100.0);
1371 assert!(cardinality >= 1.0);
1372 }
1373
1374 #[test]
1375 fn test_filter_or_selectivity() {
1376 let mut estimator = CardinalityEstimator::new();
1377 estimator.add_table_stats("Person", TableStats::new(1000));
1378
1379 let filter = LogicalOperator::Filter(FilterOp {
1383 predicate: LogicalExpression::Binary {
1384 left: Box::new(LogicalExpression::Binary {
1385 left: Box::new(LogicalExpression::Property {
1386 variable: "n".to_string(),
1387 property: "city".to_string(),
1388 }),
1389 op: BinaryOp::Eq,
1390 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1391 }),
1392 op: BinaryOp::Or,
1393 right: Box::new(LogicalExpression::Binary {
1394 left: Box::new(LogicalExpression::Property {
1395 variable: "n".to_string(),
1396 property: "city".to_string(),
1397 }),
1398 op: BinaryOp::Eq,
1399 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1400 }),
1401 },
1402 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1403 variable: "n".to_string(),
1404 label: Some("Person".to_string()),
1405 input: None,
1406 })),
1407 });
1408
1409 let cardinality = estimator.estimate(&filter);
1410 assert!(cardinality < 100.0);
1412 assert!(cardinality >= 1.0);
1413 }
1414
1415 #[test]
1416 fn test_filter_literal_true() {
1417 let mut estimator = CardinalityEstimator::new();
1418 estimator.add_table_stats("Person", TableStats::new(1000));
1419
1420 let filter = LogicalOperator::Filter(FilterOp {
1421 predicate: LogicalExpression::Literal(Value::Bool(true)),
1422 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1423 variable: "n".to_string(),
1424 label: Some("Person".to_string()),
1425 input: None,
1426 })),
1427 });
1428
1429 let cardinality = estimator.estimate(&filter);
1430 assert!((cardinality - 1000.0).abs() < 0.001);
1432 }
1433
1434 #[test]
1435 fn test_filter_literal_false() {
1436 let mut estimator = CardinalityEstimator::new();
1437 estimator.add_table_stats("Person", TableStats::new(1000));
1438
1439 let filter = LogicalOperator::Filter(FilterOp {
1440 predicate: LogicalExpression::Literal(Value::Bool(false)),
1441 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1442 variable: "n".to_string(),
1443 label: Some("Person".to_string()),
1444 input: None,
1445 })),
1446 });
1447
1448 let cardinality = estimator.estimate(&filter);
1449 assert!((cardinality - 1.0).abs() < 0.001);
1451 }
1452
1453 #[test]
1454 fn test_unary_not_selectivity() {
1455 let mut estimator = CardinalityEstimator::new();
1456 estimator.add_table_stats("Person", TableStats::new(1000));
1457
1458 let filter = LogicalOperator::Filter(FilterOp {
1459 predicate: LogicalExpression::Unary {
1460 op: UnaryOp::Not,
1461 operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1462 },
1463 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1464 variable: "n".to_string(),
1465 label: Some("Person".to_string()),
1466 input: None,
1467 })),
1468 });
1469
1470 let cardinality = estimator.estimate(&filter);
1471 assert!(cardinality < 1000.0);
1473 }
1474
1475 #[test]
1476 fn test_unary_is_null_selectivity() {
1477 let mut estimator = CardinalityEstimator::new();
1478 estimator.add_table_stats("Person", TableStats::new(1000));
1479
1480 let filter = LogicalOperator::Filter(FilterOp {
1481 predicate: LogicalExpression::Unary {
1482 op: UnaryOp::IsNull,
1483 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1484 },
1485 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1486 variable: "n".to_string(),
1487 label: Some("Person".to_string()),
1488 input: None,
1489 })),
1490 });
1491
1492 let cardinality = estimator.estimate(&filter);
1493 assert!(cardinality < 100.0);
1495 }
1496
1497 #[test]
1498 fn test_expand_cardinality() {
1499 let mut estimator = CardinalityEstimator::new();
1500 estimator.add_table_stats("Person", TableStats::new(100));
1501
1502 let expand = LogicalOperator::Expand(ExpandOp {
1503 from_variable: "a".to_string(),
1504 to_variable: "b".to_string(),
1505 edge_variable: None,
1506 direction: ExpandDirection::Outgoing,
1507 edge_type: None,
1508 min_hops: 1,
1509 max_hops: Some(1),
1510 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1511 variable: "a".to_string(),
1512 label: Some("Person".to_string()),
1513 input: None,
1514 })),
1515 path_alias: None,
1516 });
1517
1518 let cardinality = estimator.estimate(&expand);
1519 assert!(cardinality > 100.0);
1521 }
1522
1523 #[test]
1524 fn test_expand_with_edge_type_filter() {
1525 let mut estimator = CardinalityEstimator::new();
1526 estimator.add_table_stats("Person", TableStats::new(100));
1527
1528 let expand = LogicalOperator::Expand(ExpandOp {
1529 from_variable: "a".to_string(),
1530 to_variable: "b".to_string(),
1531 edge_variable: None,
1532 direction: ExpandDirection::Outgoing,
1533 edge_type: Some("KNOWS".to_string()),
1534 min_hops: 1,
1535 max_hops: Some(1),
1536 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1537 variable: "a".to_string(),
1538 label: Some("Person".to_string()),
1539 input: None,
1540 })),
1541 path_alias: None,
1542 });
1543
1544 let cardinality = estimator.estimate(&expand);
1545 assert!(cardinality > 100.0);
1547 }
1548
1549 #[test]
1550 fn test_expand_variable_length() {
1551 let mut estimator = CardinalityEstimator::new();
1552 estimator.add_table_stats("Person", TableStats::new(100));
1553
1554 let expand = LogicalOperator::Expand(ExpandOp {
1555 from_variable: "a".to_string(),
1556 to_variable: "b".to_string(),
1557 edge_variable: None,
1558 direction: ExpandDirection::Outgoing,
1559 edge_type: None,
1560 min_hops: 1,
1561 max_hops: Some(3),
1562 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1563 variable: "a".to_string(),
1564 label: Some("Person".to_string()),
1565 input: None,
1566 })),
1567 path_alias: None,
1568 });
1569
1570 let cardinality = estimator.estimate(&expand);
1571 assert!(cardinality > 500.0);
1573 }
1574
1575 #[test]
1576 fn test_join_cross_product() {
1577 let mut estimator = CardinalityEstimator::new();
1578 estimator.add_table_stats("Person", TableStats::new(100));
1579 estimator.add_table_stats("Company", TableStats::new(50));
1580
1581 let join = LogicalOperator::Join(JoinOp {
1582 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583 variable: "p".to_string(),
1584 label: Some("Person".to_string()),
1585 input: None,
1586 })),
1587 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1588 variable: "c".to_string(),
1589 label: Some("Company".to_string()),
1590 input: None,
1591 })),
1592 join_type: JoinType::Cross,
1593 conditions: vec![],
1594 });
1595
1596 let cardinality = estimator.estimate(&join);
1597 assert!((cardinality - 5000.0).abs() < 0.001);
1599 }
1600
1601 #[test]
1602 fn test_join_left_outer() {
1603 let mut estimator = CardinalityEstimator::new();
1604 estimator.add_table_stats("Person", TableStats::new(1000));
1605 estimator.add_table_stats("Company", TableStats::new(10));
1606
1607 let join = LogicalOperator::Join(JoinOp {
1608 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1609 variable: "p".to_string(),
1610 label: Some("Person".to_string()),
1611 input: None,
1612 })),
1613 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1614 variable: "c".to_string(),
1615 label: Some("Company".to_string()),
1616 input: None,
1617 })),
1618 join_type: JoinType::Left,
1619 conditions: vec![JoinCondition {
1620 left: LogicalExpression::Variable("p".to_string()),
1621 right: LogicalExpression::Variable("c".to_string()),
1622 }],
1623 });
1624
1625 let cardinality = estimator.estimate(&join);
1626 assert!(cardinality >= 1000.0);
1628 }
1629
1630 #[test]
1631 fn test_join_semi() {
1632 let mut estimator = CardinalityEstimator::new();
1633 estimator.add_table_stats("Person", TableStats::new(1000));
1634 estimator.add_table_stats("Company", TableStats::new(100));
1635
1636 let join = LogicalOperator::Join(JoinOp {
1637 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1638 variable: "p".to_string(),
1639 label: Some("Person".to_string()),
1640 input: None,
1641 })),
1642 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1643 variable: "c".to_string(),
1644 label: Some("Company".to_string()),
1645 input: None,
1646 })),
1647 join_type: JoinType::Semi,
1648 conditions: vec![],
1649 });
1650
1651 let cardinality = estimator.estimate(&join);
1652 assert!(cardinality <= 1000.0);
1654 }
1655
1656 #[test]
1657 fn test_join_anti() {
1658 let mut estimator = CardinalityEstimator::new();
1659 estimator.add_table_stats("Person", TableStats::new(1000));
1660 estimator.add_table_stats("Company", TableStats::new(100));
1661
1662 let join = LogicalOperator::Join(JoinOp {
1663 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664 variable: "p".to_string(),
1665 label: Some("Person".to_string()),
1666 input: None,
1667 })),
1668 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1669 variable: "c".to_string(),
1670 label: Some("Company".to_string()),
1671 input: None,
1672 })),
1673 join_type: JoinType::Anti,
1674 conditions: vec![],
1675 });
1676
1677 let cardinality = estimator.estimate(&join);
1678 assert!(cardinality <= 1000.0);
1680 assert!(cardinality >= 1.0);
1681 }
1682
1683 #[test]
1684 fn test_project_preserves_cardinality() {
1685 let mut estimator = CardinalityEstimator::new();
1686 estimator.add_table_stats("Person", TableStats::new(1000));
1687
1688 let project = LogicalOperator::Project(ProjectOp {
1689 projections: vec![Projection {
1690 expression: LogicalExpression::Variable("n".to_string()),
1691 alias: None,
1692 }],
1693 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1694 variable: "n".to_string(),
1695 label: Some("Person".to_string()),
1696 input: None,
1697 })),
1698 });
1699
1700 let cardinality = estimator.estimate(&project);
1701 assert!((cardinality - 1000.0).abs() < 0.001);
1702 }
1703
1704 #[test]
1705 fn test_sort_preserves_cardinality() {
1706 let mut estimator = CardinalityEstimator::new();
1707 estimator.add_table_stats("Person", TableStats::new(1000));
1708
1709 let sort = LogicalOperator::Sort(SortOp {
1710 keys: vec![SortKey {
1711 expression: LogicalExpression::Variable("n".to_string()),
1712 order: SortOrder::Ascending,
1713 }],
1714 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1715 variable: "n".to_string(),
1716 label: Some("Person".to_string()),
1717 input: None,
1718 })),
1719 });
1720
1721 let cardinality = estimator.estimate(&sort);
1722 assert!((cardinality - 1000.0).abs() < 0.001);
1723 }
1724
1725 #[test]
1726 fn test_distinct_reduces_cardinality() {
1727 let mut estimator = CardinalityEstimator::new();
1728 estimator.add_table_stats("Person", TableStats::new(1000));
1729
1730 let distinct = LogicalOperator::Distinct(DistinctOp {
1731 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1732 variable: "n".to_string(),
1733 label: Some("Person".to_string()),
1734 input: None,
1735 })),
1736 columns: None,
1737 });
1738
1739 let cardinality = estimator.estimate(&distinct);
1740 assert!((cardinality - 500.0).abs() < 0.001);
1742 }
1743
1744 #[test]
1745 fn test_skip_reduces_cardinality() {
1746 let mut estimator = CardinalityEstimator::new();
1747 estimator.add_table_stats("Person", TableStats::new(1000));
1748
1749 let skip = LogicalOperator::Skip(SkipOp {
1750 count: 100,
1751 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1752 variable: "n".to_string(),
1753 label: Some("Person".to_string()),
1754 input: None,
1755 })),
1756 });
1757
1758 let cardinality = estimator.estimate(&skip);
1759 assert!((cardinality - 900.0).abs() < 0.001);
1760 }
1761
1762 #[test]
1763 fn test_return_preserves_cardinality() {
1764 let mut estimator = CardinalityEstimator::new();
1765 estimator.add_table_stats("Person", TableStats::new(1000));
1766
1767 let ret = LogicalOperator::Return(ReturnOp {
1768 items: vec![ReturnItem {
1769 expression: LogicalExpression::Variable("n".to_string()),
1770 alias: None,
1771 }],
1772 distinct: false,
1773 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1774 variable: "n".to_string(),
1775 label: Some("Person".to_string()),
1776 input: None,
1777 })),
1778 });
1779
1780 let cardinality = estimator.estimate(&ret);
1781 assert!((cardinality - 1000.0).abs() < 0.001);
1782 }
1783
1784 #[test]
1785 fn test_empty_cardinality() {
1786 let estimator = CardinalityEstimator::new();
1787 let cardinality = estimator.estimate(&LogicalOperator::Empty);
1788 assert!((cardinality).abs() < 0.001);
1789 }
1790
1791 #[test]
1792 fn test_table_stats_with_column() {
1793 let stats = TableStats::new(1000).with_column(
1794 "age",
1795 ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1796 );
1797
1798 assert_eq!(stats.row_count, 1000);
1799 let col = stats.columns.get("age").unwrap();
1800 assert_eq!(col.distinct_count, 50);
1801 assert_eq!(col.null_count, 10);
1802 assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1803 assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1804 }
1805
1806 #[test]
1807 fn test_estimator_default() {
1808 let estimator = CardinalityEstimator::default();
1809 let scan = LogicalOperator::NodeScan(NodeScanOp {
1810 variable: "n".to_string(),
1811 label: None,
1812 input: None,
1813 });
1814 let cardinality = estimator.estimate(&scan);
1815 assert!((cardinality - 1000.0).abs() < 0.001);
1816 }
1817
1818 #[test]
1819 fn test_set_avg_fanout() {
1820 let mut estimator = CardinalityEstimator::new();
1821 estimator.add_table_stats("Person", TableStats::new(100));
1822 estimator.set_avg_fanout(5.0);
1823
1824 let expand = LogicalOperator::Expand(ExpandOp {
1825 from_variable: "a".to_string(),
1826 to_variable: "b".to_string(),
1827 edge_variable: None,
1828 direction: ExpandDirection::Outgoing,
1829 edge_type: None,
1830 min_hops: 1,
1831 max_hops: Some(1),
1832 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1833 variable: "a".to_string(),
1834 label: Some("Person".to_string()),
1835 input: None,
1836 })),
1837 path_alias: None,
1838 });
1839
1840 let cardinality = estimator.estimate(&expand);
1841 assert!((cardinality - 500.0).abs() < 0.001);
1843 }
1844
1845 #[test]
1846 fn test_multiple_group_by_keys_reduce_cardinality() {
1847 let mut estimator = CardinalityEstimator::new();
1851 estimator.add_table_stats("Person", TableStats::new(10000));
1852
1853 let single_group = LogicalOperator::Aggregate(AggregateOp {
1854 group_by: vec![LogicalExpression::Property {
1855 variable: "n".to_string(),
1856 property: "city".to_string(),
1857 }],
1858 aggregates: vec![],
1859 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1860 variable: "n".to_string(),
1861 label: Some("Person".to_string()),
1862 input: None,
1863 })),
1864 having: None,
1865 });
1866
1867 let multi_group = LogicalOperator::Aggregate(AggregateOp {
1868 group_by: vec![
1869 LogicalExpression::Property {
1870 variable: "n".to_string(),
1871 property: "city".to_string(),
1872 },
1873 LogicalExpression::Property {
1874 variable: "n".to_string(),
1875 property: "country".to_string(),
1876 },
1877 ],
1878 aggregates: vec![],
1879 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1880 variable: "n".to_string(),
1881 label: Some("Person".to_string()),
1882 input: None,
1883 })),
1884 having: None,
1885 });
1886
1887 let single_card = estimator.estimate(&single_group);
1888 let multi_card = estimator.estimate(&multi_group);
1889
1890 assert!(single_card < 10000.0);
1892 assert!(multi_card < 10000.0);
1893 assert!(single_card >= 1.0);
1895 assert!(multi_card >= 1.0);
1896 }
1897
1898 #[test]
1901 fn test_histogram_build_uniform() {
1902 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1904 let histogram = EquiDepthHistogram::build(&values, 10);
1905
1906 assert_eq!(histogram.num_buckets(), 10);
1907 assert_eq!(histogram.total_rows(), 100);
1908
1909 for bucket in histogram.buckets() {
1911 assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1912 }
1913 }
1914
1915 #[test]
1916 fn test_histogram_build_skewed() {
1917 let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1919 values.extend((0..20).map(|i| 1000.0 + i as f64));
1920 let histogram = EquiDepthHistogram::build(&values, 5);
1921
1922 assert_eq!(histogram.num_buckets(), 5);
1923 assert_eq!(histogram.total_rows(), 100);
1924
1925 for bucket in histogram.buckets() {
1927 assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1928 }
1929 }
1930
1931 #[test]
1932 fn test_histogram_range_selectivity_full() {
1933 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1934 let histogram = EquiDepthHistogram::build(&values, 10);
1935
1936 let selectivity = histogram.range_selectivity(None, None);
1938 assert!((selectivity - 1.0).abs() < 0.01);
1939 }
1940
1941 #[test]
1942 fn test_histogram_range_selectivity_half() {
1943 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1944 let histogram = EquiDepthHistogram::build(&values, 10);
1945
1946 let selectivity = histogram.range_selectivity(Some(50.0), None);
1948 assert!(selectivity > 0.4 && selectivity < 0.6);
1949 }
1950
1951 #[test]
1952 fn test_histogram_range_selectivity_quarter() {
1953 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1954 let histogram = EquiDepthHistogram::build(&values, 10);
1955
1956 let selectivity = histogram.range_selectivity(None, Some(25.0));
1958 assert!(selectivity > 0.2 && selectivity < 0.3);
1959 }
1960
1961 #[test]
1962 fn test_histogram_equality_selectivity() {
1963 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1964 let histogram = EquiDepthHistogram::build(&values, 10);
1965
1966 let selectivity = histogram.equality_selectivity(50.0);
1968 assert!(selectivity > 0.005 && selectivity < 0.02);
1969 }
1970
1971 #[test]
1972 fn test_histogram_empty() {
1973 let histogram = EquiDepthHistogram::build(&[], 10);
1974
1975 assert_eq!(histogram.num_buckets(), 0);
1976 assert_eq!(histogram.total_rows(), 0);
1977
1978 let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1980 assert!((selectivity - 0.33).abs() < 0.01);
1981 }
1982
1983 #[test]
1984 fn test_histogram_bucket_overlap() {
1985 let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1986
1987 assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1989
1990 assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1992
1993 assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1995
1996 assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1998
1999 assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
2001 }
2002
2003 #[test]
2004 fn test_column_stats_from_values() {
2005 let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2006 let stats = ColumnStats::from_values(values, 4);
2007
2008 assert_eq!(stats.distinct_count, 5); assert!(stats.min_value.is_some());
2010 assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2011 assert!(stats.max_value.is_some());
2012 assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2013 assert!(stats.histogram.is_some());
2014 }
2015
2016 #[test]
2017 fn test_column_stats_with_histogram_builder() {
2018 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2019 let histogram = EquiDepthHistogram::build(&values, 10);
2020
2021 let stats = ColumnStats::new(100)
2022 .with_range(0.0, 99.0)
2023 .with_histogram(histogram);
2024
2025 assert!(stats.histogram.is_some());
2026 assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2027 }
2028
2029 #[test]
2030 fn test_filter_with_histogram_stats() {
2031 let mut estimator = CardinalityEstimator::new();
2032
2033 let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2035 let histogram = EquiDepthHistogram::build(&age_values, 10);
2036 let age_stats = ColumnStats::new(62)
2037 .with_range(18.0, 79.0)
2038 .with_histogram(histogram);
2039
2040 estimator.add_table_stats(
2041 "Person",
2042 TableStats::new(1000).with_column("age", age_stats),
2043 );
2044
2045 let filter = LogicalOperator::Filter(FilterOp {
2048 predicate: LogicalExpression::Binary {
2049 left: Box::new(LogicalExpression::Property {
2050 variable: "n".to_string(),
2051 property: "age".to_string(),
2052 }),
2053 op: BinaryOp::Gt,
2054 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2055 },
2056 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2057 variable: "n".to_string(),
2058 label: Some("Person".to_string()),
2059 input: None,
2060 })),
2061 });
2062
2063 let cardinality = estimator.estimate(&filter);
2064
2065 assert!(cardinality > 300.0 && cardinality < 600.0);
2068 }
2069
2070 #[test]
2071 fn test_filter_equality_with_histogram() {
2072 let mut estimator = CardinalityEstimator::new();
2073
2074 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2076 let histogram = EquiDepthHistogram::build(&values, 10);
2077 let stats = ColumnStats::new(100)
2078 .with_range(0.0, 99.0)
2079 .with_histogram(histogram);
2080
2081 estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2082
2083 let filter = LogicalOperator::Filter(FilterOp {
2085 predicate: LogicalExpression::Binary {
2086 left: Box::new(LogicalExpression::Property {
2087 variable: "d".to_string(),
2088 property: "value".to_string(),
2089 }),
2090 op: BinaryOp::Eq,
2091 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2092 },
2093 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2094 variable: "d".to_string(),
2095 label: Some("Data".to_string()),
2096 input: None,
2097 })),
2098 });
2099
2100 let cardinality = estimator.estimate(&filter);
2101
2102 assert!((1.0..50.0).contains(&cardinality));
2105 }
2106
2107 #[test]
2108 fn test_histogram_min_max() {
2109 let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2110 let histogram = EquiDepthHistogram::build(&values, 2);
2111
2112 assert_eq!(histogram.min_value(), Some(5.0));
2113 assert!(histogram.max_value().is_some());
2115 }
2116
2117 #[test]
2120 fn test_selectivity_config_defaults() {
2121 let config = SelectivityConfig::new();
2122 assert!((config.default - 0.1).abs() < f64::EPSILON);
2123 assert!((config.equality - 0.01).abs() < f64::EPSILON);
2124 assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2125 assert!((config.range - 0.33).abs() < f64::EPSILON);
2126 assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2127 assert!((config.membership - 0.1).abs() < f64::EPSILON);
2128 assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2129 assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2130 assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2131 }
2132
2133 #[test]
2134 fn test_custom_selectivity_config() {
2135 let config = SelectivityConfig {
2136 equality: 0.05,
2137 range: 0.25,
2138 ..SelectivityConfig::new()
2139 };
2140 let estimator = CardinalityEstimator::with_selectivity_config(config);
2141 assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2142 assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2143 }
2144
2145 #[test]
2146 fn test_custom_selectivity_affects_estimation() {
2147 let mut default_est = CardinalityEstimator::new();
2149 default_est.add_table_stats("Person", TableStats::new(1000));
2150
2151 let filter = LogicalOperator::Filter(FilterOp {
2152 predicate: LogicalExpression::Binary {
2153 left: Box::new(LogicalExpression::Property {
2154 variable: "n".to_string(),
2155 property: "name".to_string(),
2156 }),
2157 op: BinaryOp::Eq,
2158 right: Box::new(LogicalExpression::Literal(Value::String("Alice".into()))),
2159 },
2160 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2161 variable: "n".to_string(),
2162 label: Some("Person".to_string()),
2163 input: None,
2164 })),
2165 });
2166
2167 let default_card = default_est.estimate(&filter);
2168
2169 let config = SelectivityConfig {
2171 equality: 0.2,
2172 ..SelectivityConfig::new()
2173 };
2174 let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2175 custom_est.add_table_stats("Person", TableStats::new(1000));
2176
2177 let custom_card = custom_est.estimate(&filter);
2178
2179 assert!(custom_card > default_card);
2180 assert!((custom_card - 200.0).abs() < 1.0);
2181 }
2182
2183 #[test]
2184 fn test_custom_range_selectivity() {
2185 let config = SelectivityConfig {
2186 range: 0.5,
2187 ..SelectivityConfig::new()
2188 };
2189 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2190 estimator.add_table_stats("Person", TableStats::new(1000));
2191
2192 let filter = LogicalOperator::Filter(FilterOp {
2193 predicate: LogicalExpression::Binary {
2194 left: Box::new(LogicalExpression::Property {
2195 variable: "n".to_string(),
2196 property: "age".to_string(),
2197 }),
2198 op: BinaryOp::Gt,
2199 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2200 },
2201 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2202 variable: "n".to_string(),
2203 label: Some("Person".to_string()),
2204 input: None,
2205 })),
2206 });
2207
2208 let cardinality = estimator.estimate(&filter);
2209 assert!((cardinality - 500.0).abs() < 1.0);
2211 }
2212
2213 #[test]
2214 fn test_custom_distinct_fraction() {
2215 let config = SelectivityConfig {
2216 distinct_fraction: 0.8,
2217 ..SelectivityConfig::new()
2218 };
2219 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2220 estimator.add_table_stats("Person", TableStats::new(1000));
2221
2222 let distinct = LogicalOperator::Distinct(DistinctOp {
2223 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2224 variable: "n".to_string(),
2225 label: Some("Person".to_string()),
2226 input: None,
2227 })),
2228 columns: None,
2229 });
2230
2231 let cardinality = estimator.estimate(&distinct);
2232 assert!((cardinality - 800.0).abs() < 1.0);
2234 }
2235
2236 #[test]
2239 fn test_estimation_log_basic() {
2240 let mut log = EstimationLog::new(10.0);
2241 log.record("NodeScan(Person)", 1000.0, 1200.0);
2242 log.record("Filter(age > 30)", 100.0, 90.0);
2243
2244 assert_eq!(log.entries().len(), 2);
2245 assert!(!log.should_replan()); }
2247
2248 #[test]
2249 fn test_estimation_log_triggers_replan() {
2250 let mut log = EstimationLog::new(10.0);
2251 log.record("NodeScan(Person)", 100.0, 5000.0); assert!(log.should_replan());
2254 }
2255
2256 #[test]
2257 fn test_estimation_log_overestimate_triggers_replan() {
2258 let mut log = EstimationLog::new(5.0);
2259 log.record("Filter", 1000.0, 100.0); assert!(log.should_replan()); }
2263
2264 #[test]
2265 fn test_estimation_entry_error_ratio() {
2266 let entry = EstimationEntry {
2267 operator: "test".into(),
2268 estimated: 100.0,
2269 actual: 200.0,
2270 };
2271 assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2272
2273 let perfect = EstimationEntry {
2274 operator: "test".into(),
2275 estimated: 100.0,
2276 actual: 100.0,
2277 };
2278 assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2279
2280 let zero_est = EstimationEntry {
2281 operator: "test".into(),
2282 estimated: 0.0,
2283 actual: 0.0,
2284 };
2285 assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2286 }
2287
2288 #[test]
2289 fn test_estimation_log_max_error_ratio() {
2290 let mut log = EstimationLog::new(10.0);
2291 log.record("A", 100.0, 300.0); log.record("B", 100.0, 50.0); log.record("C", 100.0, 100.0); assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2296 }
2297
2298 #[test]
2299 fn test_estimation_log_clear() {
2300 let mut log = EstimationLog::new(10.0);
2301 log.record("A", 100.0, 100.0);
2302 assert_eq!(log.entries().len(), 1);
2303
2304 log.clear();
2305 assert!(log.entries().is_empty());
2306 assert!(!log.should_replan());
2307 }
2308
2309 #[test]
2310 fn test_create_estimation_log() {
2311 let log = CardinalityEstimator::create_estimation_log();
2312 assert!(log.entries().is_empty());
2313 assert!(!log.should_replan());
2314 }
2315}