1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19 LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20 UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram, covering the half-open value range
/// `[lower_bound, upper_bound)`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket's range.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket's range.
    pub upper_bound: f64,
    /// Number of rows that fall into this bucket.
    pub frequency: u64,
    /// Number of distinct values among those rows.
    pub distinct_count: u64,
}
40
41impl HistogramBucket {
42 #[must_use]
44 pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
45 Self {
46 lower_bound,
47 upper_bound,
48 frequency,
49 distinct_count,
50 }
51 }
52
53 #[must_use]
55 pub fn width(&self) -> f64 {
56 self.upper_bound - self.lower_bound
57 }
58
59 #[must_use]
61 pub fn contains(&self, value: f64) -> bool {
62 value >= self.lower_bound && value < self.upper_bound
63 }
64
65 #[must_use]
67 pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
68 let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
69 let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);
70
71 let bucket_width = self.width();
72 if bucket_width <= 0.0 {
73 return if effective_lower <= self.lower_bound && effective_upper >= self.upper_bound {
74 1.0
75 } else {
76 0.0
77 };
78 }
79
80 let overlap = (effective_upper - effective_lower).max(0.0);
81 (overlap / bucket_width).min(1.0)
82 }
83}
84
85#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106 buckets: Vec<HistogramBucket>,
108 total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113 #[must_use]
115 pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116 let total_rows = buckets.iter().map(|b| b.frequency).sum();
117 Self {
118 buckets,
119 total_rows,
120 }
121 }
122
123 #[must_use]
132 pub fn build(values: &[f64], num_buckets: usize) -> Self {
133 if values.is_empty() || num_buckets == 0 {
134 return Self {
135 buckets: Vec::new(),
136 total_rows: 0,
137 };
138 }
139
140 let num_buckets = num_buckets.min(values.len());
141 let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142 let mut buckets = Vec::with_capacity(num_buckets);
143
144 let mut start_idx = 0;
145 while start_idx < values.len() {
146 let end_idx = (start_idx + rows_per_bucket).min(values.len());
147 let lower_bound = values[start_idx];
148 let upper_bound = if end_idx < values.len() {
149 values[end_idx]
150 } else {
151 values[end_idx - 1] + 1.0
153 };
154
155 let bucket_values = &values[start_idx..end_idx];
157 let distinct_count = count_distinct(bucket_values);
158
159 buckets.push(HistogramBucket::new(
160 lower_bound,
161 upper_bound,
162 (end_idx - start_idx) as u64,
163 distinct_count,
164 ));
165
166 start_idx = end_idx;
167 }
168
169 Self::new(buckets)
170 }
171
172 #[must_use]
174 pub fn num_buckets(&self) -> usize {
175 self.buckets.len()
176 }
177
178 #[must_use]
180 pub fn total_rows(&self) -> u64 {
181 self.total_rows
182 }
183
184 #[must_use]
186 pub fn buckets(&self) -> &[HistogramBucket] {
187 &self.buckets
188 }
189
190 #[must_use]
199 pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200 if self.buckets.is_empty() || self.total_rows == 0 {
201 return 0.33; }
203
204 let mut matching_rows = 0.0;
205
206 for bucket in &self.buckets {
207 let bucket_lower = bucket.lower_bound;
209 let bucket_upper = bucket.upper_bound;
210
211 if let Some(l) = lower
213 && bucket_upper <= l
214 {
215 continue;
216 }
217 if let Some(u) = upper
218 && bucket_lower >= u
219 {
220 continue;
221 }
222
223 let overlap = bucket.overlap_fraction(lower, upper);
225 matching_rows += overlap * bucket.frequency as f64;
226 }
227
228 (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229 }
230
231 #[must_use]
235 pub fn equality_selectivity(&self, value: f64) -> f64 {
236 if self.buckets.is_empty() || self.total_rows == 0 {
237 return 0.01; }
239
240 for bucket in &self.buckets {
242 if bucket.contains(value) {
243 if bucket.distinct_count > 0 {
245 return (bucket.frequency as f64
246 / bucket.distinct_count as f64
247 / self.total_rows as f64)
248 .min(1.0);
249 }
250 }
251 }
252
253 0.001
255 }
256
257 #[must_use]
259 pub fn min_value(&self) -> Option<f64> {
260 self.buckets.first().map(|b| b.lower_bound)
261 }
262
263 #[must_use]
265 pub fn max_value(&self) -> Option<f64> {
266 self.buckets.last().map(|b| b.upper_bound)
267 }
268}
269
/// Counts the distinct values in an ascending-sorted slice.
///
/// A value starts a new run when it differs from the first value of the
/// current run by more than `f64::EPSILON` (an absolute tolerance).
/// Returns 0 for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let Some((&first, rest)) = sorted_values.split_first() else {
        return 0;
    };

    // Carry (distinct values so far, first value of the current run).
    let (distinct, _) = rest.iter().fold((1u64, first), |(count, run_start), &value| {
        if (value - run_start).abs() > f64::EPSILON {
            (count + 1, value)
        } else {
            (count, run_start)
        }
    });
    distinct
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
/// Default selectivities (fraction of rows kept) used by the cardinality
/// estimator when no column statistics apply.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Fallback for predicates without a more specific rule.
    pub default: f64,
    /// Equality comparisons (`=`).
    pub equality: f64,
    /// Inequality comparisons (`<>`).
    pub inequality: f64,
    /// Range comparisons (`<`, `<=`, `>`, `>=`).
    pub range: f64,
    /// String predicates (starts/ends with, contains, like).
    pub string_ops: f64,
    /// Membership tests (`IN`).
    pub membership: f64,
    /// `IS NULL`.
    pub is_null: f64,
    /// `IS NOT NULL`.
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive `DISTINCT`.
    pub distinct_fraction: f64,
}
422
423impl SelectivityConfig {
424 #[must_use]
426 pub fn new() -> Self {
427 Self {
428 default: 0.1,
429 equality: 0.01,
430 inequality: 0.99,
431 range: 0.33,
432 string_ops: 0.1,
433 membership: 0.1,
434 is_null: 0.05,
435 is_not_null: 0.95,
436 distinct_fraction: 0.5,
437 }
438 }
439}
440
441impl Default for SelectivityConfig {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
/// One estimated-vs-actual cardinality observation.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Name of the operator the observation belongs to.
    pub operator: String,
    /// Cardinality predicted by the estimator.
    pub estimated: f64,
    /// Cardinality that was actually observed.
    pub actual: f64,
}
457
458impl EstimationEntry {
459 #[must_use]
465 pub fn error_ratio(&self) -> f64 {
466 if self.estimated.abs() < f64::EPSILON {
467 if self.actual.abs() < f64::EPSILON {
468 1.0
469 } else {
470 f64::INFINITY
471 }
472 } else {
473 self.actual / self.estimated
474 }
475 }
476}
477
478#[derive(Debug, Clone, Default)]
485pub struct EstimationLog {
486 entries: Vec<EstimationEntry>,
488 replan_threshold: f64,
492}
493
494impl EstimationLog {
495 #[must_use]
497 pub fn new(replan_threshold: f64) -> Self {
498 Self {
499 entries: Vec::new(),
500 replan_threshold,
501 }
502 }
503
504 pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
506 self.entries.push(EstimationEntry {
507 operator: operator.into(),
508 estimated,
509 actual,
510 });
511 }
512
513 #[must_use]
515 pub fn entries(&self) -> &[EstimationEntry] {
516 &self.entries
517 }
518
519 #[must_use]
522 pub fn should_replan(&self) -> bool {
523 self.entries.iter().any(|e| {
524 let ratio = e.error_ratio();
525 ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
526 })
527 }
528
529 #[must_use]
531 pub fn max_error_ratio(&self) -> f64 {
532 self.entries
533 .iter()
534 .map(|e| {
535 let r = e.error_ratio();
536 if r < 1.0 { 1.0 / r } else { r }
538 })
539 .fold(1.0_f64, f64::max)
540 }
541
542 pub fn clear(&mut self) {
544 self.entries.clear();
545 }
546}
547
548pub struct CardinalityEstimator {
550 table_stats: HashMap<String, TableStats>,
552 default_row_count: u64,
554 default_selectivity: f64,
556 avg_fanout: f64,
558 selectivity_config: SelectivityConfig,
560}
561
562impl CardinalityEstimator {
563 #[must_use]
565 pub fn new() -> Self {
566 let config = SelectivityConfig::new();
567 Self {
568 table_stats: HashMap::new(),
569 default_row_count: 1000,
570 default_selectivity: config.default,
571 avg_fanout: 10.0,
572 selectivity_config: config,
573 }
574 }
575
576 #[must_use]
578 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579 Self {
580 table_stats: HashMap::new(),
581 default_row_count: 1000,
582 default_selectivity: config.default,
583 avg_fanout: 10.0,
584 selectivity_config: config,
585 }
586 }
587
588 #[must_use]
590 pub fn selectivity_config(&self) -> &SelectivityConfig {
591 &self.selectivity_config
592 }
593
594 #[must_use]
596 pub fn create_estimation_log() -> EstimationLog {
597 EstimationLog::new(10.0)
598 }
599
600 #[must_use]
606 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607 let mut estimator = Self::new();
608
609 if stats.total_nodes > 0 {
611 estimator.default_row_count = stats.total_nodes;
612 }
613
614 for (label, label_stats) in &stats.labels {
616 let mut table_stats = TableStats::new(label_stats.node_count);
617
618 for (prop, col_stats) in &label_stats.properties {
620 let optimizer_col =
621 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622 table_stats = table_stats.with_column(prop, optimizer_col);
623 }
624
625 estimator.add_table_stats(label, table_stats);
626 }
627
628 if !stats.edge_types.is_empty() {
630 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632 } else if stats.total_nodes > 0 {
633 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634 }
635
636 if estimator.avg_fanout < 1.0 {
638 estimator.avg_fanout = 1.0;
639 }
640
641 estimator
642 }
643
644 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646 self.table_stats.insert(name.to_string(), stats);
647 }
648
649 pub fn set_avg_fanout(&mut self, fanout: f64) {
651 self.avg_fanout = fanout;
652 }
653
654 #[must_use]
656 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657 match op {
658 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660 LogicalOperator::Project(project) => self.estimate_project(project),
661 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662 LogicalOperator::Join(join) => self.estimate_join(join),
663 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668 LogicalOperator::Return(ret) => self.estimate(&ret.input),
669 LogicalOperator::Empty => 0.0,
670 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672 LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
673 _ => self.default_row_count as f64,
674 }
675 }
676
677 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
679 if let Some(label) = &scan.label
680 && let Some(stats) = self.table_stats.get(label)
681 {
682 return stats.row_count as f64;
683 }
684 self.default_row_count as f64
686 }
687
688 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
690 let input_cardinality = self.estimate(&filter.input);
691 let selectivity = self.estimate_selectivity(&filter.predicate);
692 (input_cardinality * selectivity).max(1.0)
693 }
694
695 fn estimate_project(&self, project: &ProjectOp) -> f64 {
697 self.estimate(&project.input)
698 }
699
700 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
702 let input_cardinality = self.estimate(&expand.input);
703
704 let fanout = if !expand.edge_types.is_empty() {
706 self.avg_fanout * 0.5
708 } else {
709 self.avg_fanout
710 };
711
712 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
714 let min = expand.min_hops as f64;
715 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
716 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
718 } else {
719 fanout
720 };
721
722 (input_cardinality * path_multiplier).max(1.0)
723 }
724
725 fn estimate_join(&self, join: &JoinOp) -> f64 {
727 let left_card = self.estimate(&join.left);
728 let right_card = self.estimate(&join.right);
729
730 match join.join_type {
731 JoinType::Cross => left_card * right_card,
732 JoinType::Inner => {
733 let selectivity = if join.conditions.is_empty() {
735 1.0 } else {
737 0.1_f64.powi(join.conditions.len() as i32)
739 };
740 (left_card * right_card * selectivity).max(1.0)
741 }
742 JoinType::Left => {
743 let inner_card = self.estimate_join(&JoinOp {
745 left: join.left.clone(),
746 right: join.right.clone(),
747 join_type: JoinType::Inner,
748 conditions: join.conditions.clone(),
749 });
750 inner_card.max(left_card)
751 }
752 JoinType::Right => {
753 let inner_card = self.estimate_join(&JoinOp {
755 left: join.left.clone(),
756 right: join.right.clone(),
757 join_type: JoinType::Inner,
758 conditions: join.conditions.clone(),
759 });
760 inner_card.max(right_card)
761 }
762 JoinType::Full => {
763 let inner_card = self.estimate_join(&JoinOp {
765 left: join.left.clone(),
766 right: join.right.clone(),
767 join_type: JoinType::Inner,
768 conditions: join.conditions.clone(),
769 });
770 inner_card.max(left_card.max(right_card))
771 }
772 JoinType::Semi => {
773 (left_card * self.default_selectivity).max(1.0)
775 }
776 JoinType::Anti => {
777 (left_card * (1.0 - self.default_selectivity)).max(1.0)
779 }
780 }
781 }
782
783 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
785 let input_cardinality = self.estimate(&agg.input);
786
787 if agg.group_by.is_empty() {
788 1.0
790 } else {
791 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
794 (input_cardinality / group_reduction).max(1.0)
795 }
796 }
797
798 fn estimate_sort(&self, sort: &SortOp) -> f64 {
800 self.estimate(&sort.input)
801 }
802
803 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
805 let input_cardinality = self.estimate(&distinct.input);
806 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
807 }
808
809 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
811 let input_cardinality = self.estimate(&limit.input);
812 (limit.count as f64).min(input_cardinality)
813 }
814
815 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
817 let input_cardinality = self.estimate(&skip.input);
818 (input_cardinality - skip.count as f64).max(0.0)
819 }
820
821 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
826 let base_k = scan.k as f64;
827
828 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
830 0.7
832 } else {
833 1.0
834 };
835
836 (base_k * selectivity).max(1.0)
837 }
838
839 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
843 let input_cardinality = self.estimate(&join.input);
844 let k = join.k as f64;
845
846 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
848 0.7
849 } else {
850 1.0
851 };
852
853 (input_cardinality * k * selectivity).max(1.0)
854 }
855
856 fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
861 if mwj.inputs.is_empty() {
862 return 0.0;
863 }
864 let cardinalities: Vec<f64> = mwj
865 .inputs
866 .iter()
867 .map(|input| self.estimate(input))
868 .collect();
869 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
870 let n = cardinalities.len() as f64;
871 (min_card.powf(n / 2.0)).max(1.0)
873 }
874
875 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
877 match expr {
878 LogicalExpression::Binary { left, op, right } => {
879 self.estimate_binary_selectivity(left, *op, right)
880 }
881 LogicalExpression::Unary { op, operand } => {
882 self.estimate_unary_selectivity(*op, operand)
883 }
884 LogicalExpression::Literal(value) => {
885 if let grafeo_common::types::Value::Bool(b) = value {
887 if *b { 1.0 } else { 0.0 }
888 } else {
889 self.default_selectivity
890 }
891 }
892 _ => self.default_selectivity,
893 }
894 }
895
896 fn estimate_binary_selectivity(
898 &self,
899 left: &LogicalExpression,
900 op: BinaryOp,
901 right: &LogicalExpression,
902 ) -> f64 {
903 match op {
904 BinaryOp::Eq => {
906 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
907 return selectivity;
908 }
909 self.selectivity_config.equality
910 }
911 BinaryOp::Ne => self.selectivity_config.inequality,
913 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
915 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
916 return selectivity;
917 }
918 self.selectivity_config.range
919 }
920 BinaryOp::And => {
922 let left_sel = self.estimate_selectivity(left);
923 let right_sel = self.estimate_selectivity(right);
924 left_sel * right_sel
926 }
927 BinaryOp::Or => {
928 let left_sel = self.estimate_selectivity(left);
929 let right_sel = self.estimate_selectivity(right);
930 (left_sel + right_sel - left_sel * right_sel).min(1.0)
933 }
934 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
936 self.selectivity_config.string_ops
937 }
938 BinaryOp::In => self.selectivity_config.membership,
940 _ => self.default_selectivity,
942 }
943 }
944
945 fn try_equality_selectivity(
947 &self,
948 left: &LogicalExpression,
949 right: &LogicalExpression,
950 ) -> Option<f64> {
951 let (label, column, value) = self.extract_column_and_value(left, right)?;
953
954 let stats = self.get_column_stats(&label, &column)?;
956
957 if let Some(ref histogram) = stats.histogram {
959 return Some(histogram.equality_selectivity(value));
960 }
961
962 if stats.distinct_count > 0 {
964 return Some(1.0 / stats.distinct_count as f64);
965 }
966
967 None
968 }
969
970 fn try_range_selectivity(
972 &self,
973 left: &LogicalExpression,
974 op: BinaryOp,
975 right: &LogicalExpression,
976 ) -> Option<f64> {
977 let (label, column, value) = self.extract_column_and_value(left, right)?;
979
980 let stats = self.get_column_stats(&label, &column)?;
982
983 let (lower, upper) = match op {
985 BinaryOp::Lt => (None, Some(value)),
986 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
987 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
988 BinaryOp::Ge => (Some(value), None),
989 _ => return None,
990 };
991
992 if let Some(ref histogram) = stats.histogram {
994 return Some(histogram.range_selectivity(lower, upper));
995 }
996
997 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
999 let range = max - min;
1000 if range <= 0.0 {
1001 return Some(1.0);
1002 }
1003
1004 let effective_lower = lower.unwrap_or(min).max(min);
1005 let effective_upper = upper.unwrap_or(max).min(max);
1006 let overlap = (effective_upper - effective_lower).max(0.0);
1007 return Some((overlap / range).clamp(0.0, 1.0));
1008 }
1009
1010 None
1011 }
1012
1013 fn extract_column_and_value(
1018 &self,
1019 left: &LogicalExpression,
1020 right: &LogicalExpression,
1021 ) -> Option<(String, String, f64)> {
1022 if let Some(result) = self.try_extract_property_literal(left, right) {
1024 return Some(result);
1025 }
1026
1027 self.try_extract_property_literal(right, left)
1029 }
1030
1031 fn try_extract_property_literal(
1033 &self,
1034 property_expr: &LogicalExpression,
1035 literal_expr: &LogicalExpression,
1036 ) -> Option<(String, String, f64)> {
1037 let (variable, property) = match property_expr {
1039 LogicalExpression::Property { variable, property } => {
1040 (variable.clone(), property.clone())
1041 }
1042 _ => return None,
1043 };
1044
1045 let value = match literal_expr {
1047 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1048 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1049 _ => return None,
1050 };
1051
1052 for label in self.table_stats.keys() {
1056 if let Some(stats) = self.table_stats.get(label)
1057 && stats.columns.contains_key(&property)
1058 {
1059 return Some((label.clone(), property, value));
1060 }
1061 }
1062
1063 Some((variable, property, value))
1065 }
1066
1067 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1069 match op {
1070 UnaryOp::Not => 1.0 - self.default_selectivity,
1071 UnaryOp::IsNull => self.selectivity_config.is_null,
1072 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1073 UnaryOp::Neg => 1.0, }
1075 }
1076
1077 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1079 self.table_stats.get(label)?.columns.get(column)
1080 }
1081}
1082
1083impl Default for CardinalityEstimator {
1084 fn default() -> Self {
1085 Self::new()
1086 }
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091 use super::*;
1092 use crate::query::plan::{
1093 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1094 ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1095 };
1096 use grafeo_common::types::Value;
1097
1098 #[test]
1099 fn test_node_scan_with_stats() {
1100 let mut estimator = CardinalityEstimator::new();
1101 estimator.add_table_stats("Person", TableStats::new(5000));
1102
1103 let scan = LogicalOperator::NodeScan(NodeScanOp {
1104 variable: "n".to_string(),
1105 label: Some("Person".to_string()),
1106 input: None,
1107 });
1108
1109 let cardinality = estimator.estimate(&scan);
1110 assert!((cardinality - 5000.0).abs() < 0.001);
1111 }
1112
1113 #[test]
1114 fn test_filter_reduces_cardinality() {
1115 let mut estimator = CardinalityEstimator::new();
1116 estimator.add_table_stats("Person", TableStats::new(1000));
1117
1118 let filter = LogicalOperator::Filter(FilterOp {
1119 predicate: LogicalExpression::Binary {
1120 left: Box::new(LogicalExpression::Property {
1121 variable: "n".to_string(),
1122 property: "age".to_string(),
1123 }),
1124 op: BinaryOp::Eq,
1125 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1126 },
1127 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1128 variable: "n".to_string(),
1129 label: Some("Person".to_string()),
1130 input: None,
1131 })),
1132 pushdown_hint: None,
1133 });
1134
1135 let cardinality = estimator.estimate(&filter);
1136 assert!(cardinality < 1000.0);
1138 assert!(cardinality >= 1.0);
1139 }
1140
1141 #[test]
1142 fn test_join_cardinality() {
1143 let mut estimator = CardinalityEstimator::new();
1144 estimator.add_table_stats("Person", TableStats::new(1000));
1145 estimator.add_table_stats("Company", TableStats::new(100));
1146
1147 let join = LogicalOperator::Join(JoinOp {
1148 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1149 variable: "p".to_string(),
1150 label: Some("Person".to_string()),
1151 input: None,
1152 })),
1153 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1154 variable: "c".to_string(),
1155 label: Some("Company".to_string()),
1156 input: None,
1157 })),
1158 join_type: JoinType::Inner,
1159 conditions: vec![JoinCondition {
1160 left: LogicalExpression::Property {
1161 variable: "p".to_string(),
1162 property: "company_id".to_string(),
1163 },
1164 right: LogicalExpression::Property {
1165 variable: "c".to_string(),
1166 property: "id".to_string(),
1167 },
1168 }],
1169 });
1170
1171 let cardinality = estimator.estimate(&join);
1172 assert!(cardinality < 1000.0 * 100.0);
1174 }
1175
1176 #[test]
1177 fn test_limit_caps_cardinality() {
1178 let mut estimator = CardinalityEstimator::new();
1179 estimator.add_table_stats("Person", TableStats::new(1000));
1180
1181 let limit = LogicalOperator::Limit(LimitOp {
1182 count: 10,
1183 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1184 variable: "n".to_string(),
1185 label: Some("Person".to_string()),
1186 input: None,
1187 })),
1188 });
1189
1190 let cardinality = estimator.estimate(&limit);
1191 assert!((cardinality - 10.0).abs() < 0.001);
1192 }
1193
1194 #[test]
1195 fn test_aggregate_reduces_cardinality() {
1196 let mut estimator = CardinalityEstimator::new();
1197 estimator.add_table_stats("Person", TableStats::new(1000));
1198
1199 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1201 group_by: vec![],
1202 aggregates: vec![],
1203 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1204 variable: "n".to_string(),
1205 label: Some("Person".to_string()),
1206 input: None,
1207 })),
1208 having: None,
1209 });
1210
1211 let cardinality = estimator.estimate(&global_agg);
1212 assert!((cardinality - 1.0).abs() < 0.001);
1213
1214 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1216 group_by: vec![LogicalExpression::Property {
1217 variable: "n".to_string(),
1218 property: "city".to_string(),
1219 }],
1220 aggregates: vec![],
1221 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1222 variable: "n".to_string(),
1223 label: Some("Person".to_string()),
1224 input: None,
1225 })),
1226 having: None,
1227 });
1228
1229 let cardinality = estimator.estimate(&group_agg);
1230 assert!(cardinality < 1000.0);
1232 }
1233
1234 #[test]
1235 fn test_node_scan_without_stats() {
1236 let estimator = CardinalityEstimator::new();
1237
1238 let scan = LogicalOperator::NodeScan(NodeScanOp {
1239 variable: "n".to_string(),
1240 label: Some("Unknown".to_string()),
1241 input: None,
1242 });
1243
1244 let cardinality = estimator.estimate(&scan);
1245 assert!((cardinality - 1000.0).abs() < 0.001);
1247 }
1248
1249 #[test]
1250 fn test_node_scan_no_label() {
1251 let estimator = CardinalityEstimator::new();
1252
1253 let scan = LogicalOperator::NodeScan(NodeScanOp {
1254 variable: "n".to_string(),
1255 label: None,
1256 input: None,
1257 });
1258
1259 let cardinality = estimator.estimate(&scan);
1260 assert!((cardinality - 1000.0).abs() < 0.001);
1262 }
1263
1264 #[test]
1265 fn test_filter_inequality_selectivity() {
1266 let mut estimator = CardinalityEstimator::new();
1267 estimator.add_table_stats("Person", TableStats::new(1000));
1268
1269 let filter = LogicalOperator::Filter(FilterOp {
1270 predicate: LogicalExpression::Binary {
1271 left: Box::new(LogicalExpression::Property {
1272 variable: "n".to_string(),
1273 property: "age".to_string(),
1274 }),
1275 op: BinaryOp::Ne,
1276 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1277 },
1278 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1279 variable: "n".to_string(),
1280 label: Some("Person".to_string()),
1281 input: None,
1282 })),
1283 pushdown_hint: None,
1284 });
1285
1286 let cardinality = estimator.estimate(&filter);
1287 assert!(cardinality > 900.0);
1289 }
1290
1291 #[test]
1292 fn test_filter_range_selectivity() {
1293 let mut estimator = CardinalityEstimator::new();
1294 estimator.add_table_stats("Person", TableStats::new(1000));
1295
1296 let filter = LogicalOperator::Filter(FilterOp {
1297 predicate: LogicalExpression::Binary {
1298 left: Box::new(LogicalExpression::Property {
1299 variable: "n".to_string(),
1300 property: "age".to_string(),
1301 }),
1302 op: BinaryOp::Gt,
1303 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1304 },
1305 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1306 variable: "n".to_string(),
1307 label: Some("Person".to_string()),
1308 input: None,
1309 })),
1310 pushdown_hint: None,
1311 });
1312
1313 let cardinality = estimator.estimate(&filter);
1314 assert!(cardinality < 500.0);
1316 assert!(cardinality > 100.0);
1317 }
1318
1319 #[test]
1320 fn test_filter_and_selectivity() {
1321 let mut estimator = CardinalityEstimator::new();
1322 estimator.add_table_stats("Person", TableStats::new(1000));
1323
1324 let filter = LogicalOperator::Filter(FilterOp {
1327 predicate: LogicalExpression::Binary {
1328 left: Box::new(LogicalExpression::Binary {
1329 left: Box::new(LogicalExpression::Property {
1330 variable: "n".to_string(),
1331 property: "city".to_string(),
1332 }),
1333 op: BinaryOp::Eq,
1334 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1335 }),
1336 op: BinaryOp::And,
1337 right: Box::new(LogicalExpression::Binary {
1338 left: Box::new(LogicalExpression::Property {
1339 variable: "n".to_string(),
1340 property: "age".to_string(),
1341 }),
1342 op: BinaryOp::Eq,
1343 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1344 }),
1345 },
1346 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1347 variable: "n".to_string(),
1348 label: Some("Person".to_string()),
1349 input: None,
1350 })),
1351 pushdown_hint: None,
1352 });
1353
1354 let cardinality = estimator.estimate(&filter);
1355 assert!(cardinality < 100.0);
1358 assert!(cardinality >= 1.0);
1359 }
1360
1361 #[test]
1362 fn test_filter_or_selectivity() {
1363 let mut estimator = CardinalityEstimator::new();
1364 estimator.add_table_stats("Person", TableStats::new(1000));
1365
1366 let filter = LogicalOperator::Filter(FilterOp {
1370 predicate: LogicalExpression::Binary {
1371 left: Box::new(LogicalExpression::Binary {
1372 left: Box::new(LogicalExpression::Property {
1373 variable: "n".to_string(),
1374 property: "city".to_string(),
1375 }),
1376 op: BinaryOp::Eq,
1377 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1378 }),
1379 op: BinaryOp::Or,
1380 right: Box::new(LogicalExpression::Binary {
1381 left: Box::new(LogicalExpression::Property {
1382 variable: "n".to_string(),
1383 property: "city".to_string(),
1384 }),
1385 op: BinaryOp::Eq,
1386 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1387 }),
1388 },
1389 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1390 variable: "n".to_string(),
1391 label: Some("Person".to_string()),
1392 input: None,
1393 })),
1394 pushdown_hint: None,
1395 });
1396
1397 let cardinality = estimator.estimate(&filter);
1398 assert!(cardinality < 100.0);
1400 assert!(cardinality >= 1.0);
1401 }
1402
1403 #[test]
1404 fn test_filter_literal_true() {
1405 let mut estimator = CardinalityEstimator::new();
1406 estimator.add_table_stats("Person", TableStats::new(1000));
1407
1408 let filter = LogicalOperator::Filter(FilterOp {
1409 predicate: LogicalExpression::Literal(Value::Bool(true)),
1410 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1411 variable: "n".to_string(),
1412 label: Some("Person".to_string()),
1413 input: None,
1414 })),
1415 pushdown_hint: None,
1416 });
1417
1418 let cardinality = estimator.estimate(&filter);
1419 assert!((cardinality - 1000.0).abs() < 0.001);
1421 }
1422
1423 #[test]
1424 fn test_filter_literal_false() {
1425 let mut estimator = CardinalityEstimator::new();
1426 estimator.add_table_stats("Person", TableStats::new(1000));
1427
1428 let filter = LogicalOperator::Filter(FilterOp {
1429 predicate: LogicalExpression::Literal(Value::Bool(false)),
1430 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1431 variable: "n".to_string(),
1432 label: Some("Person".to_string()),
1433 input: None,
1434 })),
1435 pushdown_hint: None,
1436 });
1437
1438 let cardinality = estimator.estimate(&filter);
1439 assert!((cardinality - 1.0).abs() < 0.001);
1441 }
1442
1443 #[test]
1444 fn test_unary_not_selectivity() {
1445 let mut estimator = CardinalityEstimator::new();
1446 estimator.add_table_stats("Person", TableStats::new(1000));
1447
1448 let filter = LogicalOperator::Filter(FilterOp {
1449 predicate: LogicalExpression::Unary {
1450 op: UnaryOp::Not,
1451 operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1452 },
1453 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1454 variable: "n".to_string(),
1455 label: Some("Person".to_string()),
1456 input: None,
1457 })),
1458 pushdown_hint: None,
1459 });
1460
1461 let cardinality = estimator.estimate(&filter);
1462 assert!(cardinality < 1000.0);
1464 }
1465
1466 #[test]
1467 fn test_unary_is_null_selectivity() {
1468 let mut estimator = CardinalityEstimator::new();
1469 estimator.add_table_stats("Person", TableStats::new(1000));
1470
1471 let filter = LogicalOperator::Filter(FilterOp {
1472 predicate: LogicalExpression::Unary {
1473 op: UnaryOp::IsNull,
1474 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1475 },
1476 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1477 variable: "n".to_string(),
1478 label: Some("Person".to_string()),
1479 input: None,
1480 })),
1481 pushdown_hint: None,
1482 });
1483
1484 let cardinality = estimator.estimate(&filter);
1485 assert!(cardinality < 100.0);
1487 }
1488
1489 #[test]
1490 fn test_expand_cardinality() {
1491 let mut estimator = CardinalityEstimator::new();
1492 estimator.add_table_stats("Person", TableStats::new(100));
1493
1494 let expand = LogicalOperator::Expand(ExpandOp {
1495 from_variable: "a".to_string(),
1496 to_variable: "b".to_string(),
1497 edge_variable: None,
1498 direction: ExpandDirection::Outgoing,
1499 edge_types: vec![],
1500 min_hops: 1,
1501 max_hops: Some(1),
1502 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1503 variable: "a".to_string(),
1504 label: Some("Person".to_string()),
1505 input: None,
1506 })),
1507 path_alias: None,
1508 path_mode: PathMode::Walk,
1509 });
1510
1511 let cardinality = estimator.estimate(&expand);
1512 assert!(cardinality > 100.0);
1514 }
1515
1516 #[test]
1517 fn test_expand_with_edge_type_filter() {
1518 let mut estimator = CardinalityEstimator::new();
1519 estimator.add_table_stats("Person", TableStats::new(100));
1520
1521 let expand = LogicalOperator::Expand(ExpandOp {
1522 from_variable: "a".to_string(),
1523 to_variable: "b".to_string(),
1524 edge_variable: None,
1525 direction: ExpandDirection::Outgoing,
1526 edge_types: vec!["KNOWS".to_string()],
1527 min_hops: 1,
1528 max_hops: Some(1),
1529 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1530 variable: "a".to_string(),
1531 label: Some("Person".to_string()),
1532 input: None,
1533 })),
1534 path_alias: None,
1535 path_mode: PathMode::Walk,
1536 });
1537
1538 let cardinality = estimator.estimate(&expand);
1539 assert!(cardinality > 100.0);
1541 }
1542
1543 #[test]
1544 fn test_expand_variable_length() {
1545 let mut estimator = CardinalityEstimator::new();
1546 estimator.add_table_stats("Person", TableStats::new(100));
1547
1548 let expand = LogicalOperator::Expand(ExpandOp {
1549 from_variable: "a".to_string(),
1550 to_variable: "b".to_string(),
1551 edge_variable: None,
1552 direction: ExpandDirection::Outgoing,
1553 edge_types: vec![],
1554 min_hops: 1,
1555 max_hops: Some(3),
1556 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557 variable: "a".to_string(),
1558 label: Some("Person".to_string()),
1559 input: None,
1560 })),
1561 path_alias: None,
1562 path_mode: PathMode::Walk,
1563 });
1564
1565 let cardinality = estimator.estimate(&expand);
1566 assert!(cardinality > 500.0);
1568 }
1569
1570 #[test]
1571 fn test_join_cross_product() {
1572 let mut estimator = CardinalityEstimator::new();
1573 estimator.add_table_stats("Person", TableStats::new(100));
1574 estimator.add_table_stats("Company", TableStats::new(50));
1575
1576 let join = LogicalOperator::Join(JoinOp {
1577 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1578 variable: "p".to_string(),
1579 label: Some("Person".to_string()),
1580 input: None,
1581 })),
1582 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583 variable: "c".to_string(),
1584 label: Some("Company".to_string()),
1585 input: None,
1586 })),
1587 join_type: JoinType::Cross,
1588 conditions: vec![],
1589 });
1590
1591 let cardinality = estimator.estimate(&join);
1592 assert!((cardinality - 5000.0).abs() < 0.001);
1594 }
1595
1596 #[test]
1597 fn test_join_left_outer() {
1598 let mut estimator = CardinalityEstimator::new();
1599 estimator.add_table_stats("Person", TableStats::new(1000));
1600 estimator.add_table_stats("Company", TableStats::new(10));
1601
1602 let join = LogicalOperator::Join(JoinOp {
1603 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1604 variable: "p".to_string(),
1605 label: Some("Person".to_string()),
1606 input: None,
1607 })),
1608 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1609 variable: "c".to_string(),
1610 label: Some("Company".to_string()),
1611 input: None,
1612 })),
1613 join_type: JoinType::Left,
1614 conditions: vec![JoinCondition {
1615 left: LogicalExpression::Variable("p".to_string()),
1616 right: LogicalExpression::Variable("c".to_string()),
1617 }],
1618 });
1619
1620 let cardinality = estimator.estimate(&join);
1621 assert!(cardinality >= 1000.0);
1623 }
1624
1625 #[test]
1626 fn test_join_semi() {
1627 let mut estimator = CardinalityEstimator::new();
1628 estimator.add_table_stats("Person", TableStats::new(1000));
1629 estimator.add_table_stats("Company", TableStats::new(100));
1630
1631 let join = LogicalOperator::Join(JoinOp {
1632 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1633 variable: "p".to_string(),
1634 label: Some("Person".to_string()),
1635 input: None,
1636 })),
1637 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1638 variable: "c".to_string(),
1639 label: Some("Company".to_string()),
1640 input: None,
1641 })),
1642 join_type: JoinType::Semi,
1643 conditions: vec![],
1644 });
1645
1646 let cardinality = estimator.estimate(&join);
1647 assert!(cardinality <= 1000.0);
1649 }
1650
1651 #[test]
1652 fn test_join_anti() {
1653 let mut estimator = CardinalityEstimator::new();
1654 estimator.add_table_stats("Person", TableStats::new(1000));
1655 estimator.add_table_stats("Company", TableStats::new(100));
1656
1657 let join = LogicalOperator::Join(JoinOp {
1658 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1659 variable: "p".to_string(),
1660 label: Some("Person".to_string()),
1661 input: None,
1662 })),
1663 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664 variable: "c".to_string(),
1665 label: Some("Company".to_string()),
1666 input: None,
1667 })),
1668 join_type: JoinType::Anti,
1669 conditions: vec![],
1670 });
1671
1672 let cardinality = estimator.estimate(&join);
1673 assert!(cardinality <= 1000.0);
1675 assert!(cardinality >= 1.0);
1676 }
1677
1678 #[test]
1679 fn test_project_preserves_cardinality() {
1680 let mut estimator = CardinalityEstimator::new();
1681 estimator.add_table_stats("Person", TableStats::new(1000));
1682
1683 let project = LogicalOperator::Project(ProjectOp {
1684 projections: vec![Projection {
1685 expression: LogicalExpression::Variable("n".to_string()),
1686 alias: None,
1687 }],
1688 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1689 variable: "n".to_string(),
1690 label: Some("Person".to_string()),
1691 input: None,
1692 })),
1693 });
1694
1695 let cardinality = estimator.estimate(&project);
1696 assert!((cardinality - 1000.0).abs() < 0.001);
1697 }
1698
1699 #[test]
1700 fn test_sort_preserves_cardinality() {
1701 let mut estimator = CardinalityEstimator::new();
1702 estimator.add_table_stats("Person", TableStats::new(1000));
1703
1704 let sort = LogicalOperator::Sort(SortOp {
1705 keys: vec![SortKey {
1706 expression: LogicalExpression::Variable("n".to_string()),
1707 order: SortOrder::Ascending,
1708 nulls: None,
1709 }],
1710 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1711 variable: "n".to_string(),
1712 label: Some("Person".to_string()),
1713 input: None,
1714 })),
1715 });
1716
1717 let cardinality = estimator.estimate(&sort);
1718 assert!((cardinality - 1000.0).abs() < 0.001);
1719 }
1720
1721 #[test]
1722 fn test_distinct_reduces_cardinality() {
1723 let mut estimator = CardinalityEstimator::new();
1724 estimator.add_table_stats("Person", TableStats::new(1000));
1725
1726 let distinct = LogicalOperator::Distinct(DistinctOp {
1727 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1728 variable: "n".to_string(),
1729 label: Some("Person".to_string()),
1730 input: None,
1731 })),
1732 columns: None,
1733 });
1734
1735 let cardinality = estimator.estimate(&distinct);
1736 assert!((cardinality - 500.0).abs() < 0.001);
1738 }
1739
1740 #[test]
1741 fn test_skip_reduces_cardinality() {
1742 let mut estimator = CardinalityEstimator::new();
1743 estimator.add_table_stats("Person", TableStats::new(1000));
1744
1745 let skip = LogicalOperator::Skip(SkipOp {
1746 count: 100,
1747 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1748 variable: "n".to_string(),
1749 label: Some("Person".to_string()),
1750 input: None,
1751 })),
1752 });
1753
1754 let cardinality = estimator.estimate(&skip);
1755 assert!((cardinality - 900.0).abs() < 0.001);
1756 }
1757
1758 #[test]
1759 fn test_return_preserves_cardinality() {
1760 let mut estimator = CardinalityEstimator::new();
1761 estimator.add_table_stats("Person", TableStats::new(1000));
1762
1763 let ret = LogicalOperator::Return(ReturnOp {
1764 items: vec![ReturnItem {
1765 expression: LogicalExpression::Variable("n".to_string()),
1766 alias: None,
1767 }],
1768 distinct: false,
1769 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1770 variable: "n".to_string(),
1771 label: Some("Person".to_string()),
1772 input: None,
1773 })),
1774 });
1775
1776 let cardinality = estimator.estimate(&ret);
1777 assert!((cardinality - 1000.0).abs() < 0.001);
1778 }
1779
1780 #[test]
1781 fn test_empty_cardinality() {
1782 let estimator = CardinalityEstimator::new();
1783 let cardinality = estimator.estimate(&LogicalOperator::Empty);
1784 assert!((cardinality).abs() < 0.001);
1785 }
1786
1787 #[test]
1788 fn test_table_stats_with_column() {
1789 let stats = TableStats::new(1000).with_column(
1790 "age",
1791 ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1792 );
1793
1794 assert_eq!(stats.row_count, 1000);
1795 let col = stats.columns.get("age").unwrap();
1796 assert_eq!(col.distinct_count, 50);
1797 assert_eq!(col.null_count, 10);
1798 assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1799 assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1800 }
1801
1802 #[test]
1803 fn test_estimator_default() {
1804 let estimator = CardinalityEstimator::default();
1805 let scan = LogicalOperator::NodeScan(NodeScanOp {
1806 variable: "n".to_string(),
1807 label: None,
1808 input: None,
1809 });
1810 let cardinality = estimator.estimate(&scan);
1811 assert!((cardinality - 1000.0).abs() < 0.001);
1812 }
1813
1814 #[test]
1815 fn test_set_avg_fanout() {
1816 let mut estimator = CardinalityEstimator::new();
1817 estimator.add_table_stats("Person", TableStats::new(100));
1818 estimator.set_avg_fanout(5.0);
1819
1820 let expand = LogicalOperator::Expand(ExpandOp {
1821 from_variable: "a".to_string(),
1822 to_variable: "b".to_string(),
1823 edge_variable: None,
1824 direction: ExpandDirection::Outgoing,
1825 edge_types: vec![],
1826 min_hops: 1,
1827 max_hops: Some(1),
1828 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1829 variable: "a".to_string(),
1830 label: Some("Person".to_string()),
1831 input: None,
1832 })),
1833 path_alias: None,
1834 path_mode: PathMode::Walk,
1835 });
1836
1837 let cardinality = estimator.estimate(&expand);
1838 assert!((cardinality - 500.0).abs() < 0.001);
1840 }
1841
1842 #[test]
1843 fn test_multiple_group_by_keys_reduce_cardinality() {
1844 let mut estimator = CardinalityEstimator::new();
1848 estimator.add_table_stats("Person", TableStats::new(10000));
1849
1850 let single_group = LogicalOperator::Aggregate(AggregateOp {
1851 group_by: vec![LogicalExpression::Property {
1852 variable: "n".to_string(),
1853 property: "city".to_string(),
1854 }],
1855 aggregates: vec![],
1856 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1857 variable: "n".to_string(),
1858 label: Some("Person".to_string()),
1859 input: None,
1860 })),
1861 having: None,
1862 });
1863
1864 let multi_group = LogicalOperator::Aggregate(AggregateOp {
1865 group_by: vec![
1866 LogicalExpression::Property {
1867 variable: "n".to_string(),
1868 property: "city".to_string(),
1869 },
1870 LogicalExpression::Property {
1871 variable: "n".to_string(),
1872 property: "country".to_string(),
1873 },
1874 ],
1875 aggregates: vec![],
1876 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1877 variable: "n".to_string(),
1878 label: Some("Person".to_string()),
1879 input: None,
1880 })),
1881 having: None,
1882 });
1883
1884 let single_card = estimator.estimate(&single_group);
1885 let multi_card = estimator.estimate(&multi_group);
1886
1887 assert!(single_card < 10000.0);
1889 assert!(multi_card < 10000.0);
1890 assert!(single_card >= 1.0);
1892 assert!(multi_card >= 1.0);
1893 }
1894
1895 #[test]
1898 fn test_histogram_build_uniform() {
1899 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1901 let histogram = EquiDepthHistogram::build(&values, 10);
1902
1903 assert_eq!(histogram.num_buckets(), 10);
1904 assert_eq!(histogram.total_rows(), 100);
1905
1906 for bucket in histogram.buckets() {
1908 assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1909 }
1910 }
1911
1912 #[test]
1913 fn test_histogram_build_skewed() {
1914 let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1916 values.extend((0..20).map(|i| 1000.0 + i as f64));
1917 let histogram = EquiDepthHistogram::build(&values, 5);
1918
1919 assert_eq!(histogram.num_buckets(), 5);
1920 assert_eq!(histogram.total_rows(), 100);
1921
1922 for bucket in histogram.buckets() {
1924 assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1925 }
1926 }
1927
1928 #[test]
1929 fn test_histogram_range_selectivity_full() {
1930 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1931 let histogram = EquiDepthHistogram::build(&values, 10);
1932
1933 let selectivity = histogram.range_selectivity(None, None);
1935 assert!((selectivity - 1.0).abs() < 0.01);
1936 }
1937
1938 #[test]
1939 fn test_histogram_range_selectivity_half() {
1940 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1941 let histogram = EquiDepthHistogram::build(&values, 10);
1942
1943 let selectivity = histogram.range_selectivity(Some(50.0), None);
1945 assert!(selectivity > 0.4 && selectivity < 0.6);
1946 }
1947
1948 #[test]
1949 fn test_histogram_range_selectivity_quarter() {
1950 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1951 let histogram = EquiDepthHistogram::build(&values, 10);
1952
1953 let selectivity = histogram.range_selectivity(None, Some(25.0));
1955 assert!(selectivity > 0.2 && selectivity < 0.3);
1956 }
1957
1958 #[test]
1959 fn test_histogram_equality_selectivity() {
1960 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1961 let histogram = EquiDepthHistogram::build(&values, 10);
1962
1963 let selectivity = histogram.equality_selectivity(50.0);
1965 assert!(selectivity > 0.005 && selectivity < 0.02);
1966 }
1967
1968 #[test]
1969 fn test_histogram_empty() {
1970 let histogram = EquiDepthHistogram::build(&[], 10);
1971
1972 assert_eq!(histogram.num_buckets(), 0);
1973 assert_eq!(histogram.total_rows(), 0);
1974
1975 let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1977 assert!((selectivity - 0.33).abs() < 0.01);
1978 }
1979
1980 #[test]
1981 fn test_histogram_bucket_overlap() {
1982 let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1983
1984 assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1986
1987 assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1989
1990 assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1992
1993 assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1995
1996 assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1998 }
1999
2000 #[test]
2001 fn test_column_stats_from_values() {
2002 let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2003 let stats = ColumnStats::from_values(values, 4);
2004
2005 assert_eq!(stats.distinct_count, 5); assert!(stats.min_value.is_some());
2007 assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2008 assert!(stats.max_value.is_some());
2009 assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2010 assert!(stats.histogram.is_some());
2011 }
2012
2013 #[test]
2014 fn test_column_stats_with_histogram_builder() {
2015 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2016 let histogram = EquiDepthHistogram::build(&values, 10);
2017
2018 let stats = ColumnStats::new(100)
2019 .with_range(0.0, 99.0)
2020 .with_histogram(histogram);
2021
2022 assert!(stats.histogram.is_some());
2023 assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2024 }
2025
2026 #[test]
2027 fn test_filter_with_histogram_stats() {
2028 let mut estimator = CardinalityEstimator::new();
2029
2030 let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2032 let histogram = EquiDepthHistogram::build(&age_values, 10);
2033 let age_stats = ColumnStats::new(62)
2034 .with_range(18.0, 79.0)
2035 .with_histogram(histogram);
2036
2037 estimator.add_table_stats(
2038 "Person",
2039 TableStats::new(1000).with_column("age", age_stats),
2040 );
2041
2042 let filter = LogicalOperator::Filter(FilterOp {
2045 predicate: LogicalExpression::Binary {
2046 left: Box::new(LogicalExpression::Property {
2047 variable: "n".to_string(),
2048 property: "age".to_string(),
2049 }),
2050 op: BinaryOp::Gt,
2051 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2052 },
2053 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2054 variable: "n".to_string(),
2055 label: Some("Person".to_string()),
2056 input: None,
2057 })),
2058 pushdown_hint: None,
2059 });
2060
2061 let cardinality = estimator.estimate(&filter);
2062
2063 assert!(cardinality > 300.0 && cardinality < 600.0);
2066 }
2067
2068 #[test]
2069 fn test_filter_equality_with_histogram() {
2070 let mut estimator = CardinalityEstimator::new();
2071
2072 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2074 let histogram = EquiDepthHistogram::build(&values, 10);
2075 let stats = ColumnStats::new(100)
2076 .with_range(0.0, 99.0)
2077 .with_histogram(histogram);
2078
2079 estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2080
2081 let filter = LogicalOperator::Filter(FilterOp {
2083 predicate: LogicalExpression::Binary {
2084 left: Box::new(LogicalExpression::Property {
2085 variable: "d".to_string(),
2086 property: "value".to_string(),
2087 }),
2088 op: BinaryOp::Eq,
2089 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2090 },
2091 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2092 variable: "d".to_string(),
2093 label: Some("Data".to_string()),
2094 input: None,
2095 })),
2096 pushdown_hint: None,
2097 });
2098
2099 let cardinality = estimator.estimate(&filter);
2100
2101 assert!((1.0..50.0).contains(&cardinality));
2104 }
2105
2106 #[test]
2107 fn test_histogram_min_max() {
2108 let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2109 let histogram = EquiDepthHistogram::build(&values, 2);
2110
2111 assert_eq!(histogram.min_value(), Some(5.0));
2112 assert!(histogram.max_value().is_some());
2114 }
2115
2116 #[test]
2119 fn test_selectivity_config_defaults() {
2120 let config = SelectivityConfig::new();
2121 assert!((config.default - 0.1).abs() < f64::EPSILON);
2122 assert!((config.equality - 0.01).abs() < f64::EPSILON);
2123 assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2124 assert!((config.range - 0.33).abs() < f64::EPSILON);
2125 assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2126 assert!((config.membership - 0.1).abs() < f64::EPSILON);
2127 assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2128 assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2129 assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2130 }
2131
2132 #[test]
2133 fn test_custom_selectivity_config() {
2134 let config = SelectivityConfig {
2135 equality: 0.05,
2136 range: 0.25,
2137 ..SelectivityConfig::new()
2138 };
2139 let estimator = CardinalityEstimator::with_selectivity_config(config);
2140 assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2141 assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2142 }
2143
2144 #[test]
2145 fn test_custom_selectivity_affects_estimation() {
2146 let mut default_est = CardinalityEstimator::new();
2148 default_est.add_table_stats("Person", TableStats::new(1000));
2149
2150 let filter = LogicalOperator::Filter(FilterOp {
2151 predicate: LogicalExpression::Binary {
2152 left: Box::new(LogicalExpression::Property {
2153 variable: "n".to_string(),
2154 property: "name".to_string(),
2155 }),
2156 op: BinaryOp::Eq,
2157 right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2158 },
2159 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2160 variable: "n".to_string(),
2161 label: Some("Person".to_string()),
2162 input: None,
2163 })),
2164 pushdown_hint: None,
2165 });
2166
2167 let default_card = default_est.estimate(&filter);
2168
2169 let config = SelectivityConfig {
2171 equality: 0.2,
2172 ..SelectivityConfig::new()
2173 };
2174 let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2175 custom_est.add_table_stats("Person", TableStats::new(1000));
2176
2177 let custom_card = custom_est.estimate(&filter);
2178
2179 assert!(custom_card > default_card);
2180 assert!((custom_card - 200.0).abs() < 1.0);
2181 }
2182
2183 #[test]
2184 fn test_custom_range_selectivity() {
2185 let config = SelectivityConfig {
2186 range: 0.5,
2187 ..SelectivityConfig::new()
2188 };
2189 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2190 estimator.add_table_stats("Person", TableStats::new(1000));
2191
2192 let filter = LogicalOperator::Filter(FilterOp {
2193 predicate: LogicalExpression::Binary {
2194 left: Box::new(LogicalExpression::Property {
2195 variable: "n".to_string(),
2196 property: "age".to_string(),
2197 }),
2198 op: BinaryOp::Gt,
2199 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2200 },
2201 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2202 variable: "n".to_string(),
2203 label: Some("Person".to_string()),
2204 input: None,
2205 })),
2206 pushdown_hint: None,
2207 });
2208
2209 let cardinality = estimator.estimate(&filter);
2210 assert!((cardinality - 500.0).abs() < 1.0);
2212 }
2213
2214 #[test]
2215 fn test_custom_distinct_fraction() {
2216 let config = SelectivityConfig {
2217 distinct_fraction: 0.8,
2218 ..SelectivityConfig::new()
2219 };
2220 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2221 estimator.add_table_stats("Person", TableStats::new(1000));
2222
2223 let distinct = LogicalOperator::Distinct(DistinctOp {
2224 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2225 variable: "n".to_string(),
2226 label: Some("Person".to_string()),
2227 input: None,
2228 })),
2229 columns: None,
2230 });
2231
2232 let cardinality = estimator.estimate(&distinct);
2233 assert!((cardinality - 800.0).abs() < 1.0);
2235 }
2236
2237 #[test]
2240 fn test_estimation_log_basic() {
2241 let mut log = EstimationLog::new(10.0);
2242 log.record("NodeScan(Person)", 1000.0, 1200.0);
2243 log.record("Filter(age > 30)", 100.0, 90.0);
2244
2245 assert_eq!(log.entries().len(), 2);
2246 assert!(!log.should_replan()); }
2248
2249 #[test]
2250 fn test_estimation_log_triggers_replan() {
2251 let mut log = EstimationLog::new(10.0);
2252 log.record("NodeScan(Person)", 100.0, 5000.0); assert!(log.should_replan());
2255 }
2256
2257 #[test]
2258 fn test_estimation_log_overestimate_triggers_replan() {
2259 let mut log = EstimationLog::new(5.0);
2260 log.record("Filter", 1000.0, 100.0); assert!(log.should_replan()); }
2264
2265 #[test]
2266 fn test_estimation_entry_error_ratio() {
2267 let entry = EstimationEntry {
2268 operator: "test".into(),
2269 estimated: 100.0,
2270 actual: 200.0,
2271 };
2272 assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2273
2274 let perfect = EstimationEntry {
2275 operator: "test".into(),
2276 estimated: 100.0,
2277 actual: 100.0,
2278 };
2279 assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2280
2281 let zero_est = EstimationEntry {
2282 operator: "test".into(),
2283 estimated: 0.0,
2284 actual: 0.0,
2285 };
2286 assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2287 }
2288
2289 #[test]
2290 fn test_estimation_log_max_error_ratio() {
2291 let mut log = EstimationLog::new(10.0);
2292 log.record("A", 100.0, 300.0); log.record("B", 100.0, 50.0); log.record("C", 100.0, 100.0); assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2297 }
2298
2299 #[test]
2300 fn test_estimation_log_clear() {
2301 let mut log = EstimationLog::new(10.0);
2302 log.record("A", 100.0, 100.0);
2303 assert_eq!(log.entries().len(), 1);
2304
2305 log.clear();
2306 assert!(log.entries().is_empty());
2307 assert!(!log.should_replan());
2308 }
2309
2310 #[test]
2311 fn test_create_estimation_log() {
2312 let log = CardinalityEstimator::create_estimation_log();
2313 assert!(log.entries().is_empty());
2314 assert!(!log.should_replan());
2315 }
2316}