1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19 LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20 UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram, covering the half-open value
/// interval `[lower_bound, upper_bound)`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket.
    pub upper_bound: f64,
    /// Number of rows whose value falls into this bucket.
    pub frequency: u64,
    /// Number of distinct values among those rows.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and its row / distinct-value counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Width of the bucket's value interval (`upper_bound - lower_bound`).
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` if `value` lies in `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        value >= self.lower_bound && value < self.upper_bound
    }

    /// Fraction of this bucket's width covered by the range from `lower` to
    /// `upper` (`None` leaves that side unbounded), assuming values are spread
    /// uniformly across the bucket. The result is in `[0.0, 1.0]`.
    ///
    /// A bucket with non-positive width counts as fully covered (`1.0`) when
    /// the range includes its single point and as uncovered (`0.0`) otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            return if effective_lower <= self.lower_bound && effective_upper >= self.upper_bound {
                1.0
            } else {
                0.0
            };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}

/// Histogram whose buckets each hold roughly the same number of rows.
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// Buckets as supplied to [`EquiDepthHistogram::new`]; [`EquiDepthHistogram::build`]
    /// produces them in ascending value order.
    buckets: Vec<HistogramBucket>,
    /// Sum of the frequencies of all buckets.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Wraps pre-computed buckets; the total row count is the sum of their frequencies.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds a histogram with at most `num_buckets` buckets from `values`,
    /// which must be sorted in ascending order and contain no NaN.
    ///
    /// Each bucket targets `ceil(len / num_buckets)` rows, but a run of equal
    /// values is never split across two buckets: a bucket's exclusive upper
    /// bound is the first value of the next bucket, so every copy of a value
    /// must live in the same bucket. Heavily duplicated data can therefore
    /// produce fewer buckets than requested. The last bucket's upper bound is
    /// the largest value plus `1.0`.
    ///
    /// Empty input or `num_buckets == 0` yields an empty histogram.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self {
                buckets: Vec::new(),
                total_rows: 0,
            };
        }

        let num_buckets = num_buckets.min(values.len());
        // At least 1 because 1 <= num_buckets <= values.len().
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        let mut start_idx = 0;
        while start_idx < values.len() {
            let mut end_idx = (start_idx + rows_per_bucket).min(values.len());
            // Extend the bucket over the rest of a run of equal values. Otherwise the
            // bucket's exclusive upper bound (`values[end_idx]`) would equal values it
            // holds, and `contains` would route those values to the wrong bucket.
            while end_idx < values.len() && values[end_idx] == values[end_idx - 1] {
                end_idx += 1;
            }

            let lower_bound = values[start_idx];
            let upper_bound = match values.get(end_idx) {
                Some(&next) => next,
                // Last bucket: pad past the maximum so the maximum itself is contained.
                None => values[end_idx - 1] + 1.0,
            };

            let bucket_values = &values[start_idx..end_idx];
            buckets.push(HistogramBucket::new(
                lower_bound,
                upper_bound,
                bucket_values.len() as u64,
                count_distinct(bucket_values),
            ));

            start_idx = end_idx;
        }

        Self::new(buckets)
    }

    /// Number of buckets in the histogram.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Total number of rows summarised by the histogram.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// The histogram's buckets.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimated fraction of rows whose value falls between `lower` and `upper`.
    /// `None` leaves that side unbounded.
    ///
    /// Rows are counted per bucket, in proportion to how much of the bucket's
    /// width the range covers. Returns `0.33` when the histogram has no data.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33;
        }

        let mut matching_rows = 0.0;

        for bucket in &self.buckets {
            // Skip buckets lying entirely below the lower bound...
            if let Some(l) = lower {
                if bucket.upper_bound <= l {
                    continue;
                }
            }
            // ...or entirely at/above the upper bound.
            if let Some(u) = upper {
                if bucket.lower_bound >= u {
                    continue;
                }
            }

            matching_rows += bucket.overlap_fraction(lower, upper) * bucket.frequency as f64;
        }

        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
    }

    /// Estimated fraction of rows equal to `value`.
    ///
    /// Uses the first bucket containing `value` (with a non-zero distinct
    /// count) and assumes its rows are spread evenly over its distinct values:
    /// `frequency / distinct_count / total_rows`, capped at `1.0`.
    /// Returns `0.01` when the histogram has no data and `0.001` when no
    /// bucket matches.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01;
        }

        for bucket in &self.buckets {
            if bucket.contains(value) && bucket.distinct_count > 0 {
                return (bucket.frequency as f64
                    / bucket.distinct_count as f64
                    / self.total_rows as f64)
                    .min(1.0);
            }
        }

        0.001
    }

    /// Lower bound of the first bucket, if any.
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Exclusive upper bound of the last bucket, if any. For histograms made by
    /// [`EquiDepthHistogram::build`] this is the largest sample plus `1.0`.
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts distinct values in an ascending-sorted slice.
///
/// Equal values are adjacent in sorted input, so the count is one plus the
/// number of neighbouring pairs that differ. Values are compared exactly (an
/// absolute epsilon would merge distinct small-magnitude values). Assumes the
/// slice holds no NaN; each NaN would be counted as its own value.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    1 + sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count() as u64
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
/// Heuristic selectivities (fractions of rows kept, in `[0, 1]`) used when no
/// column statistics are available for a predicate.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Fallback for predicates without a more specific rule.
    pub default: f64,
    /// `a = b`.
    pub equality: f64,
    /// `a <> b`.
    pub inequality: f64,
    /// `<`, `<=`, `>`, `>=`.
    pub range: f64,
    /// String predicates (`STARTS WITH`, `ENDS WITH`, `CONTAINS`, `LIKE`).
    pub string_ops: f64,
    /// `IN` membership tests.
    pub membership: f64,
    /// `IS NULL`.
    pub is_null: f64,
    /// `IS NOT NULL`.
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive `DISTINCT`.
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns the built-in defaults (identical to [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    fn default() -> Self {
        Self {
            distinct_fraction: 0.5,
            is_not_null: 0.95,
            is_null: 0.05,
            membership: 0.1,
            string_ops: 0.1,
            range: 0.33,
            inequality: 0.99,
            equality: 0.01,
            default: 0.1,
        }
    }
}
446
/// One cardinality estimate paired with the row count observed at run time.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Name of the operator the estimate was made for.
    pub operator: String,
    /// Estimated row count.
    pub estimated: f64,
    /// Observed row count.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// When the estimate is (nearly) zero, the ratio is `1.0` if the actual
    /// count is also (nearly) zero, and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() < f64::EPSILON {
            if self.actual.abs() < f64::EPSILON {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            self.actual / self.estimated
        }
    }
}

/// Records estimate-versus-actual pairs and decides whether the
/// mis-estimation is large enough to warrant re-planning.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries, in insertion order.
    entries: Vec<EstimationEntry>,
    /// Largest tolerated factor between estimate and actual, in either direction.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// An empty log using [`EstimationLog::DEFAULT_REPLAN_THRESHOLD`].
    ///
    /// Implemented by hand: a derived `Default` would set the threshold to
    /// `0.0`, which makes [`EstimationLog::should_replan`] return `true` as
    /// soon as any entry is recorded, even a perfectly accurate one.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REPLAN_THRESHOLD)
    }
}

impl EstimationLog {
    /// Threshold used by `Default`: a 10x mis-estimate in either direction.
    pub const DEFAULT_REPLAN_THRESHOLD: f64 = 10.0;

    /// Creates an empty log with the given re-planning threshold (expected to be >= 1).
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Appends an estimate-versus-actual observation for `operator`.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// All recorded entries, in insertion order.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns `true` if any entry's error ratio exceeds the threshold or falls
    /// below its reciprocal (over- or under-estimation by more than the threshold).
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Largest error factor across all entries, with under-estimates inverted
    /// so every factor is >= 1. Returns `1.0` for an empty log.
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Removes all recorded entries; the threshold is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
547
548pub struct CardinalityEstimator {
550 table_stats: HashMap<String, TableStats>,
552 default_row_count: u64,
554 default_selectivity: f64,
556 avg_fanout: f64,
558 selectivity_config: SelectivityConfig,
560}
561
562impl CardinalityEstimator {
563 #[must_use]
565 pub fn new() -> Self {
566 let config = SelectivityConfig::new();
567 Self {
568 table_stats: HashMap::new(),
569 default_row_count: 1000,
570 default_selectivity: config.default,
571 avg_fanout: 10.0,
572 selectivity_config: config,
573 }
574 }
575
576 #[must_use]
578 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579 Self {
580 table_stats: HashMap::new(),
581 default_row_count: 1000,
582 default_selectivity: config.default,
583 avg_fanout: 10.0,
584 selectivity_config: config,
585 }
586 }
587
588 #[must_use]
590 pub fn selectivity_config(&self) -> &SelectivityConfig {
591 &self.selectivity_config
592 }
593
594 #[must_use]
596 pub fn create_estimation_log() -> EstimationLog {
597 EstimationLog::new(10.0)
598 }
599
600 #[must_use]
606 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607 let mut estimator = Self::new();
608
609 if stats.total_nodes > 0 {
611 estimator.default_row_count = stats.total_nodes;
612 }
613
614 for (label, label_stats) in &stats.labels {
616 let mut table_stats = TableStats::new(label_stats.node_count);
617
618 for (prop, col_stats) in &label_stats.properties {
620 let optimizer_col =
621 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622 table_stats = table_stats.with_column(prop, optimizer_col);
623 }
624
625 estimator.add_table_stats(label, table_stats);
626 }
627
628 if !stats.edge_types.is_empty() {
630 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632 } else if stats.total_nodes > 0 {
633 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634 }
635
636 if estimator.avg_fanout < 1.0 {
638 estimator.avg_fanout = 1.0;
639 }
640
641 estimator
642 }
643
644 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646 self.table_stats.insert(name.to_string(), stats);
647 }
648
649 pub fn set_avg_fanout(&mut self, fanout: f64) {
651 self.avg_fanout = fanout;
652 }
653
654 #[must_use]
656 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657 match op {
658 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660 LogicalOperator::Project(project) => self.estimate_project(project),
661 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662 LogicalOperator::Join(join) => self.estimate_join(join),
663 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668 LogicalOperator::Return(ret) => self.estimate(&ret.input),
669 LogicalOperator::Empty => 0.0,
670 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672 LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
673 LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
674 _ => self.default_row_count as f64,
675 }
676 }
677
678 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
680 if let Some(label) = &scan.label
681 && let Some(stats) = self.table_stats.get(label)
682 {
683 return stats.row_count as f64;
684 }
685 self.default_row_count as f64
687 }
688
689 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
691 let input_cardinality = self.estimate(&filter.input);
692 let selectivity = self.estimate_selectivity(&filter.predicate);
693 (input_cardinality * selectivity).max(1.0)
694 }
695
696 fn estimate_project(&self, project: &ProjectOp) -> f64 {
698 self.estimate(&project.input)
699 }
700
701 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
703 let input_cardinality = self.estimate(&expand.input);
704
705 let fanout = if !expand.edge_types.is_empty() {
707 self.avg_fanout * 0.5
709 } else {
710 self.avg_fanout
711 };
712
713 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
715 let min = expand.min_hops as f64;
716 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
717 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
719 } else {
720 fanout
721 };
722
723 (input_cardinality * path_multiplier).max(1.0)
724 }
725
726 fn estimate_join(&self, join: &JoinOp) -> f64 {
728 let left_card = self.estimate(&join.left);
729 let right_card = self.estimate(&join.right);
730
731 match join.join_type {
732 JoinType::Cross => left_card * right_card,
733 JoinType::Inner => {
734 let selectivity = if join.conditions.is_empty() {
736 1.0 } else {
738 0.1_f64.powi(join.conditions.len() as i32)
740 };
741 (left_card * right_card * selectivity).max(1.0)
742 }
743 JoinType::Left => {
744 let inner_card = self.estimate_join(&JoinOp {
746 left: join.left.clone(),
747 right: join.right.clone(),
748 join_type: JoinType::Inner,
749 conditions: join.conditions.clone(),
750 });
751 inner_card.max(left_card)
752 }
753 JoinType::Right => {
754 let inner_card = self.estimate_join(&JoinOp {
756 left: join.left.clone(),
757 right: join.right.clone(),
758 join_type: JoinType::Inner,
759 conditions: join.conditions.clone(),
760 });
761 inner_card.max(right_card)
762 }
763 JoinType::Full => {
764 let inner_card = self.estimate_join(&JoinOp {
766 left: join.left.clone(),
767 right: join.right.clone(),
768 join_type: JoinType::Inner,
769 conditions: join.conditions.clone(),
770 });
771 inner_card.max(left_card.max(right_card))
772 }
773 JoinType::Semi => {
774 (left_card * self.default_selectivity).max(1.0)
776 }
777 JoinType::Anti => {
778 (left_card * (1.0 - self.default_selectivity)).max(1.0)
780 }
781 }
782 }
783
784 fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
790 let left_card = self.estimate(&lj.left);
791 let right_card = self.estimate(&lj.right);
792
793 let inner_estimate = left_card * right_card * self.default_selectivity;
795 inner_estimate.max(left_card).max(1.0)
796 }
797
798 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
800 let input_cardinality = self.estimate(&agg.input);
801
802 if agg.group_by.is_empty() {
803 1.0
805 } else {
806 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
809 (input_cardinality / group_reduction).max(1.0)
810 }
811 }
812
813 fn estimate_sort(&self, sort: &SortOp) -> f64 {
815 self.estimate(&sort.input)
816 }
817
818 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
820 let input_cardinality = self.estimate(&distinct.input);
821 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
822 }
823
824 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
826 let input_cardinality = self.estimate(&limit.input);
827 limit.count.estimate().min(input_cardinality)
828 }
829
830 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
832 let input_cardinality = self.estimate(&skip.input);
833 (input_cardinality - skip.count.estimate()).max(0.0)
834 }
835
836 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
841 let base_k = scan.k as f64;
842
843 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
845 0.7
847 } else {
848 1.0
849 };
850
851 (base_k * selectivity).max(1.0)
852 }
853
854 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
858 let input_cardinality = self.estimate(&join.input);
859 let k = join.k as f64;
860
861 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
863 0.7
864 } else {
865 1.0
866 };
867
868 (input_cardinality * k * selectivity).max(1.0)
869 }
870
871 fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
876 if mwj.inputs.is_empty() {
877 return 0.0;
878 }
879 let cardinalities: Vec<f64> = mwj
880 .inputs
881 .iter()
882 .map(|input| self.estimate(input))
883 .collect();
884 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
885 let n = cardinalities.len() as f64;
886 (min_card.powf(n / 2.0)).max(1.0)
888 }
889
890 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
892 match expr {
893 LogicalExpression::Binary { left, op, right } => {
894 self.estimate_binary_selectivity(left, *op, right)
895 }
896 LogicalExpression::Unary { op, operand } => {
897 self.estimate_unary_selectivity(*op, operand)
898 }
899 LogicalExpression::Literal(value) => {
900 if let grafeo_common::types::Value::Bool(b) = value {
902 if *b { 1.0 } else { 0.0 }
903 } else {
904 self.default_selectivity
905 }
906 }
907 _ => self.default_selectivity,
908 }
909 }
910
911 fn estimate_binary_selectivity(
913 &self,
914 left: &LogicalExpression,
915 op: BinaryOp,
916 right: &LogicalExpression,
917 ) -> f64 {
918 match op {
919 BinaryOp::Eq => {
921 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
922 return selectivity;
923 }
924 self.selectivity_config.equality
925 }
926 BinaryOp::Ne => self.selectivity_config.inequality,
928 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
930 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
931 return selectivity;
932 }
933 self.selectivity_config.range
934 }
935 BinaryOp::And => {
937 let left_sel = self.estimate_selectivity(left);
938 let right_sel = self.estimate_selectivity(right);
939 left_sel * right_sel
941 }
942 BinaryOp::Or => {
943 let left_sel = self.estimate_selectivity(left);
944 let right_sel = self.estimate_selectivity(right);
945 (left_sel + right_sel - left_sel * right_sel).min(1.0)
948 }
949 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
951 self.selectivity_config.string_ops
952 }
953 BinaryOp::In => self.selectivity_config.membership,
955 _ => self.default_selectivity,
957 }
958 }
959
960 fn try_equality_selectivity(
962 &self,
963 left: &LogicalExpression,
964 right: &LogicalExpression,
965 ) -> Option<f64> {
966 let (label, column, value) = self.extract_column_and_value(left, right)?;
968
969 let stats = self.get_column_stats(&label, &column)?;
971
972 if let Some(ref histogram) = stats.histogram {
974 return Some(histogram.equality_selectivity(value));
975 }
976
977 if stats.distinct_count > 0 {
979 return Some(1.0 / stats.distinct_count as f64);
980 }
981
982 None
983 }
984
985 fn try_range_selectivity(
987 &self,
988 left: &LogicalExpression,
989 op: BinaryOp,
990 right: &LogicalExpression,
991 ) -> Option<f64> {
992 let (label, column, value) = self.extract_column_and_value(left, right)?;
994
995 let stats = self.get_column_stats(&label, &column)?;
997
998 let (lower, upper) = match op {
1000 BinaryOp::Lt => (None, Some(value)),
1001 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1002 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1003 BinaryOp::Ge => (Some(value), None),
1004 _ => return None,
1005 };
1006
1007 if let Some(ref histogram) = stats.histogram {
1009 return Some(histogram.range_selectivity(lower, upper));
1010 }
1011
1012 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1014 let range = max - min;
1015 if range <= 0.0 {
1016 return Some(1.0);
1017 }
1018
1019 let effective_lower = lower.unwrap_or(min).max(min);
1020 let effective_upper = upper.unwrap_or(max).min(max);
1021 let overlap = (effective_upper - effective_lower).max(0.0);
1022 return Some((overlap / range).clamp(0.0, 1.0));
1023 }
1024
1025 None
1026 }
1027
1028 fn extract_column_and_value(
1033 &self,
1034 left: &LogicalExpression,
1035 right: &LogicalExpression,
1036 ) -> Option<(String, String, f64)> {
1037 if let Some(result) = self.try_extract_property_literal(left, right) {
1039 return Some(result);
1040 }
1041
1042 self.try_extract_property_literal(right, left)
1044 }
1045
1046 fn try_extract_property_literal(
1048 &self,
1049 property_expr: &LogicalExpression,
1050 literal_expr: &LogicalExpression,
1051 ) -> Option<(String, String, f64)> {
1052 let (variable, property) = match property_expr {
1054 LogicalExpression::Property { variable, property } => {
1055 (variable.clone(), property.clone())
1056 }
1057 _ => return None,
1058 };
1059
1060 let value = match literal_expr {
1062 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1063 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1064 _ => return None,
1065 };
1066
1067 for label in self.table_stats.keys() {
1071 if let Some(stats) = self.table_stats.get(label)
1072 && stats.columns.contains_key(&property)
1073 {
1074 return Some((label.clone(), property, value));
1075 }
1076 }
1077
1078 Some((variable, property, value))
1080 }
1081
1082 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1084 match op {
1085 UnaryOp::Not => 1.0 - self.default_selectivity,
1086 UnaryOp::IsNull => self.selectivity_config.is_null,
1087 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1088 UnaryOp::Neg => 1.0, }
1090 }
1091
1092 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1094 self.table_stats.get(label)?.columns.get(column)
1095 }
1096}
1097
1098impl Default for CardinalityEstimator {
1099 fn default() -> Self {
1100 Self::new()
1101 }
1102}
1103
1104#[cfg(test)]
1105mod tests {
1106 use super::*;
1107 use crate::query::plan::{
1108 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1109 ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1110 };
1111 use grafeo_common::types::Value;
1112
1113 #[test]
1114 fn test_node_scan_with_stats() {
1115 let mut estimator = CardinalityEstimator::new();
1116 estimator.add_table_stats("Person", TableStats::new(5000));
1117
1118 let scan = LogicalOperator::NodeScan(NodeScanOp {
1119 variable: "n".to_string(),
1120 label: Some("Person".to_string()),
1121 input: None,
1122 });
1123
1124 let cardinality = estimator.estimate(&scan);
1125 assert!((cardinality - 5000.0).abs() < 0.001);
1126 }
1127
1128 #[test]
1129 fn test_filter_reduces_cardinality() {
1130 let mut estimator = CardinalityEstimator::new();
1131 estimator.add_table_stats("Person", TableStats::new(1000));
1132
1133 let filter = LogicalOperator::Filter(FilterOp {
1134 predicate: LogicalExpression::Binary {
1135 left: Box::new(LogicalExpression::Property {
1136 variable: "n".to_string(),
1137 property: "age".to_string(),
1138 }),
1139 op: BinaryOp::Eq,
1140 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1141 },
1142 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1143 variable: "n".to_string(),
1144 label: Some("Person".to_string()),
1145 input: None,
1146 })),
1147 pushdown_hint: None,
1148 });
1149
1150 let cardinality = estimator.estimate(&filter);
1151 assert!(cardinality < 1000.0);
1153 assert!(cardinality >= 1.0);
1154 }
1155
1156 #[test]
1157 fn test_join_cardinality() {
1158 let mut estimator = CardinalityEstimator::new();
1159 estimator.add_table_stats("Person", TableStats::new(1000));
1160 estimator.add_table_stats("Company", TableStats::new(100));
1161
1162 let join = LogicalOperator::Join(JoinOp {
1163 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1164 variable: "p".to_string(),
1165 label: Some("Person".to_string()),
1166 input: None,
1167 })),
1168 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1169 variable: "c".to_string(),
1170 label: Some("Company".to_string()),
1171 input: None,
1172 })),
1173 join_type: JoinType::Inner,
1174 conditions: vec![JoinCondition {
1175 left: LogicalExpression::Property {
1176 variable: "p".to_string(),
1177 property: "company_id".to_string(),
1178 },
1179 right: LogicalExpression::Property {
1180 variable: "c".to_string(),
1181 property: "id".to_string(),
1182 },
1183 }],
1184 });
1185
1186 let cardinality = estimator.estimate(&join);
1187 assert!(cardinality < 1000.0 * 100.0);
1189 }
1190
1191 #[test]
1192 fn test_limit_caps_cardinality() {
1193 let mut estimator = CardinalityEstimator::new();
1194 estimator.add_table_stats("Person", TableStats::new(1000));
1195
1196 let limit = LogicalOperator::Limit(LimitOp {
1197 count: 10.into(),
1198 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1199 variable: "n".to_string(),
1200 label: Some("Person".to_string()),
1201 input: None,
1202 })),
1203 });
1204
1205 let cardinality = estimator.estimate(&limit);
1206 assert!((cardinality - 10.0).abs() < 0.001);
1207 }
1208
1209 #[test]
1210 fn test_aggregate_reduces_cardinality() {
1211 let mut estimator = CardinalityEstimator::new();
1212 estimator.add_table_stats("Person", TableStats::new(1000));
1213
1214 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1216 group_by: vec![],
1217 aggregates: vec![],
1218 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1219 variable: "n".to_string(),
1220 label: Some("Person".to_string()),
1221 input: None,
1222 })),
1223 having: None,
1224 });
1225
1226 let cardinality = estimator.estimate(&global_agg);
1227 assert!((cardinality - 1.0).abs() < 0.001);
1228
1229 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1231 group_by: vec![LogicalExpression::Property {
1232 variable: "n".to_string(),
1233 property: "city".to_string(),
1234 }],
1235 aggregates: vec![],
1236 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1237 variable: "n".to_string(),
1238 label: Some("Person".to_string()),
1239 input: None,
1240 })),
1241 having: None,
1242 });
1243
1244 let cardinality = estimator.estimate(&group_agg);
1245 assert!(cardinality < 1000.0);
1247 }
1248
1249 #[test]
1250 fn test_node_scan_without_stats() {
1251 let estimator = CardinalityEstimator::new();
1252
1253 let scan = LogicalOperator::NodeScan(NodeScanOp {
1254 variable: "n".to_string(),
1255 label: Some("Unknown".to_string()),
1256 input: None,
1257 });
1258
1259 let cardinality = estimator.estimate(&scan);
1260 assert!((cardinality - 1000.0).abs() < 0.001);
1262 }
1263
1264 #[test]
1265 fn test_node_scan_no_label() {
1266 let estimator = CardinalityEstimator::new();
1267
1268 let scan = LogicalOperator::NodeScan(NodeScanOp {
1269 variable: "n".to_string(),
1270 label: None,
1271 input: None,
1272 });
1273
1274 let cardinality = estimator.estimate(&scan);
1275 assert!((cardinality - 1000.0).abs() < 0.001);
1277 }
1278
1279 #[test]
1280 fn test_filter_inequality_selectivity() {
1281 let mut estimator = CardinalityEstimator::new();
1282 estimator.add_table_stats("Person", TableStats::new(1000));
1283
1284 let filter = LogicalOperator::Filter(FilterOp {
1285 predicate: LogicalExpression::Binary {
1286 left: Box::new(LogicalExpression::Property {
1287 variable: "n".to_string(),
1288 property: "age".to_string(),
1289 }),
1290 op: BinaryOp::Ne,
1291 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1292 },
1293 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1294 variable: "n".to_string(),
1295 label: Some("Person".to_string()),
1296 input: None,
1297 })),
1298 pushdown_hint: None,
1299 });
1300
1301 let cardinality = estimator.estimate(&filter);
1302 assert!(cardinality > 900.0);
1304 }
1305
1306 #[test]
1307 fn test_filter_range_selectivity() {
1308 let mut estimator = CardinalityEstimator::new();
1309 estimator.add_table_stats("Person", TableStats::new(1000));
1310
1311 let filter = LogicalOperator::Filter(FilterOp {
1312 predicate: LogicalExpression::Binary {
1313 left: Box::new(LogicalExpression::Property {
1314 variable: "n".to_string(),
1315 property: "age".to_string(),
1316 }),
1317 op: BinaryOp::Gt,
1318 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1319 },
1320 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1321 variable: "n".to_string(),
1322 label: Some("Person".to_string()),
1323 input: None,
1324 })),
1325 pushdown_hint: None,
1326 });
1327
1328 let cardinality = estimator.estimate(&filter);
1329 assert!(cardinality < 500.0);
1331 assert!(cardinality > 100.0);
1332 }
1333
1334 #[test]
1335 fn test_filter_and_selectivity() {
1336 let mut estimator = CardinalityEstimator::new();
1337 estimator.add_table_stats("Person", TableStats::new(1000));
1338
1339 let filter = LogicalOperator::Filter(FilterOp {
1342 predicate: LogicalExpression::Binary {
1343 left: Box::new(LogicalExpression::Binary {
1344 left: Box::new(LogicalExpression::Property {
1345 variable: "n".to_string(),
1346 property: "city".to_string(),
1347 }),
1348 op: BinaryOp::Eq,
1349 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1350 }),
1351 op: BinaryOp::And,
1352 right: Box::new(LogicalExpression::Binary {
1353 left: Box::new(LogicalExpression::Property {
1354 variable: "n".to_string(),
1355 property: "age".to_string(),
1356 }),
1357 op: BinaryOp::Eq,
1358 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1359 }),
1360 },
1361 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362 variable: "n".to_string(),
1363 label: Some("Person".to_string()),
1364 input: None,
1365 })),
1366 pushdown_hint: None,
1367 });
1368
1369 let cardinality = estimator.estimate(&filter);
1370 assert!(cardinality < 100.0);
1373 assert!(cardinality >= 1.0);
1374 }
1375
1376 #[test]
1377 fn test_filter_or_selectivity() {
1378 let mut estimator = CardinalityEstimator::new();
1379 estimator.add_table_stats("Person", TableStats::new(1000));
1380
1381 let filter = LogicalOperator::Filter(FilterOp {
1385 predicate: LogicalExpression::Binary {
1386 left: Box::new(LogicalExpression::Binary {
1387 left: Box::new(LogicalExpression::Property {
1388 variable: "n".to_string(),
1389 property: "city".to_string(),
1390 }),
1391 op: BinaryOp::Eq,
1392 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1393 }),
1394 op: BinaryOp::Or,
1395 right: Box::new(LogicalExpression::Binary {
1396 left: Box::new(LogicalExpression::Property {
1397 variable: "n".to_string(),
1398 property: "city".to_string(),
1399 }),
1400 op: BinaryOp::Eq,
1401 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1402 }),
1403 },
1404 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1405 variable: "n".to_string(),
1406 label: Some("Person".to_string()),
1407 input: None,
1408 })),
1409 pushdown_hint: None,
1410 });
1411
1412 let cardinality = estimator.estimate(&filter);
1413 assert!(cardinality < 100.0);
1415 assert!(cardinality >= 1.0);
1416 }
1417
/// A constant `true` predicate lets every row of the 1000-row scan through.
#[test]
fn test_filter_literal_true() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let estimated = estimator.estimate(&filter);
    assert!((1000.0 - estimated).abs() < 0.001);
}
1437
/// A constant `false` predicate collapses the estimate to exactly one row,
/// not zero.
#[test]
fn test_filter_literal_false() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(false)),
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let estimated = estimator.estimate(&filter);
    assert!((1.0 - estimated).abs() < 0.001);
}
1457
/// `NOT true` must be estimated as filtering out at least part of the
/// 1000-row input.
#[test]
fn test_unary_not_selectivity() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let negated_true = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: negated_true,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    assert!(estimator.estimate(&filter) < 1000.0);
}
1480
/// An `IS NULL` check (here on the variable `x`) is treated as highly
/// selective: fewer than 100 of the 1000 input rows survive.
#[test]
fn test_unary_is_null_selectivity() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let is_null = LogicalExpression::Unary {
        op: UnaryOp::IsNull,
        operand: Box::new(LogicalExpression::Variable("x".to_string())),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: is_null,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    assert!(estimator.estimate(&filter) < 100.0);
}
1503
/// A single-hop outgoing expand from 100 `Person` nodes, with no edge-type
/// restriction, is estimated to produce more rows than it started with.
#[test]
fn test_expand_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(100));
        est
    };

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    assert!(estimator.estimate(&expand) > 100.0);
}
1530
/// Restricting a single-hop expand to `KNOWS` edges still yields more rows
/// than the 100 source nodes.
#[test]
fn test_expand_with_edge_type_filter() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(100));
        est
    };

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let edge_types = vec!["KNOWS".to_string()];
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types,
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    assert!(estimator.estimate(&expand) > 100.0);
}
1557
/// A variable-length (1..=3 hop) expand from 100 nodes is estimated above
/// 500 rows.
#[test]
fn test_expand_variable_length() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(100));
        est
    };

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(3),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    assert!(estimator.estimate(&expand) > 500.0);
}
1584
/// A cross join of 100 `Person` rows with 50 `Company` rows is exactly their
/// product, 5000 rows.
#[test]
fn test_join_cross_product() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(100));
        est.add_table_stats("Company", TableStats::new(50));
        est
    };

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "p".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "c".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let estimated = estimator.estimate(&join);
    assert!((5000.0 - estimated).abs() < 0.001);
}
1610
/// A left outer join keeps every left-side row, so the estimate must be at
/// least the 1000 `Person` rows.
#[test]
fn test_join_left_outer() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est.add_table_stats("Company", TableStats::new(10));
        est
    };

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "p".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "c".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let on_p_equals_c = JoinCondition {
        left: LogicalExpression::Variable("p".to_string()),
        right: LogicalExpression::Variable("c".to_string()),
    };
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Left,
        conditions: vec![on_p_equals_c],
    });

    assert!(estimator.estimate(&join) >= 1000.0);
}
1639
/// A semi join can only filter its left input, so it must not exceed the
/// 1000 `Person` rows.
#[test]
fn test_join_semi() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est.add_table_stats("Company", TableStats::new(100));
        est
    };

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "p".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "c".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    assert!(estimator.estimate(&join) <= 1000.0);
}
1665
/// An anti join is bounded above by its 1000-row left input and below by a
/// single row.
#[test]
fn test_join_anti() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est.add_table_stats("Company", TableStats::new(100));
        est
    };

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "p".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "c".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let estimated = estimator.estimate(&join);
    assert!(estimated <= 1000.0);
    assert!(estimated >= 1.0);
}
1692
/// Projecting a single variable neither adds nor removes rows: 1000 in,
/// 1000 out.
#[test]
fn test_project_preserves_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let projections = vec![Projection {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    }];
    let project = LogicalOperator::Project(ProjectOp {
        projections,
        input: Box::new(scan),
        pass_through_input: false,
    });

    let estimated = estimator.estimate(&project);
    assert!((1000.0 - estimated).abs() < 0.001);
}
1714
/// Sorting reorders rows without changing how many there are.
#[test]
fn test_sort_preserves_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let ascending_by_n = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
        nulls: None,
    };
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_by_n],
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&sort);
    assert!((1000.0 - estimated).abs() < 0.001);
}
1736
/// With the default configuration, `DISTINCT` over 1000 rows (no explicit
/// columns) is estimated at exactly 500 rows.
#[test]
fn test_distinct_reduces_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });

    let estimated = estimator.estimate(&distinct);
    assert!((500.0 - estimated).abs() < 0.001);
}
1755
/// Skipping 100 rows of a 1000-row scan leaves exactly 900.
#[test]
fn test_skip_reduces_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 100.into(),
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&skip);
    assert!((900.0 - estimated).abs() < 0.001);
}
1773
/// A non-distinct `RETURN n` emits one row per input row.
#[test]
fn test_return_preserves_cardinality() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let items = vec![ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    }];
    let ret = LogicalOperator::Return(ReturnOp {
        items,
        distinct: false,
        input: Box::new(scan),
    });

    let estimated = estimator.estimate(&ret);
    assert!((1000.0 - estimated).abs() < 0.001);
}
1795
/// The `Empty` operator produces (approximately) zero rows.
#[test]
fn test_empty_cardinality() {
    let estimated = CardinalityEstimator::new().estimate(&LogicalOperator::Empty);
    assert!(estimated.abs() < 0.001);
}
1802
/// `TableStats::with_column` stores the given `ColumnStats` under its name,
/// keeping the distinct count, null count, and min/max range intact.
#[test]
fn test_table_stats_with_column() {
    let age = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let stats = TableStats::new(1000).with_column("age", age);

    assert_eq!(stats.row_count, 1000);

    let col = stats.columns.get("age").unwrap();
    assert_eq!(col.distinct_count, 50);
    assert_eq!(col.null_count, 10);

    let (min, max) = (col.min_value.unwrap(), col.max_value.unwrap());
    assert!(min.abs() < 0.001);
    assert!((max - 100.0).abs() < 0.001);
}
1817
/// With no table statistics registered, an unlabeled scan on a default
/// estimator falls back to 1000 rows.
#[test]
fn test_estimator_default() {
    let estimator = CardinalityEstimator::default();
    let unlabeled = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });

    let estimated = estimator.estimate(&unlabeled);
    assert!((1000.0 - estimated).abs() < 0.001);
}
1829
/// With the average fan-out set to 5, a single-hop expand from 100 nodes is
/// estimated at exactly 500 rows.
#[test]
fn test_set_avg_fanout() {
    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(100));
        est.set_avg_fanout(5.0);
        est
    };

    let source = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let estimated = estimator.estimate(&expand);
    assert!((500.0 - estimated).abs() < 0.001);
}
1857
/// Grouping 10,000 `Person` rows by one key (`city`) or by two keys (`city`,
/// `country`) must give an estimate below the input size and of at least one
/// row.
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    /// Builds an aggregate over a `Person` scan, grouped by the given
    /// properties of `n`, with no aggregate functions and no `HAVING`.
    fn aggregate_by(properties: &[&str]) -> LogicalOperator {
        LogicalOperator::Aggregate(AggregateOp {
            group_by: properties
                .iter()
                .map(|property| LogicalExpression::Property {
                    variable: "n".to_string(),
                    property: property.to_string(),
                })
                .collect(),
            aggregates: Vec::new(),
            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                variable: "n".to_string(),
                label: Some("Person".to_string()),
                input: None,
            })),
            having: None,
        })
    }

    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Person", TableStats::new(10000));
        est
    };

    let single_card = estimator.estimate(&aggregate_by(&["city"]));
    let multi_card = estimator.estimate(&aggregate_by(&["city", "country"]));

    for card in [single_card, multi_card] {
        assert!(card < 10000.0);
        assert!(card >= 1.0);
    }
}
1910
/// Building 10 buckets over the uniform values 0..100 yields 10 buckets that
/// hold 100 rows in total, each with 9 to 11 rows.
#[test]
fn test_histogram_build_uniform() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let histogram = EquiDepthHistogram::build(&values, 10);

    assert_eq!(histogram.num_buckets(), 10);
    assert_eq!(histogram.total_rows(), 100);

    for bucket in histogram.buckets() {
        assert!((9..=11).contains(&bucket.frequency));
    }
}
1927
/// Even with skewed data (80 values in 0..80 and 20 values in 1000..1020), an
/// equi-depth histogram with 5 buckets keeps each bucket at 18 to 22 rows.
#[test]
fn test_histogram_build_skewed() {
    let values: Vec<f64> = (0..80)
        .map(f64::from)
        .chain((0..20).map(|i| 1000.0 + f64::from(i)))
        .collect();
    let histogram = EquiDepthHistogram::build(&values, 5);

    assert_eq!(histogram.num_buckets(), 5);
    assert_eq!(histogram.total_rows(), 100);

    for bucket in histogram.buckets() {
        assert!((18..=22).contains(&bucket.frequency));
    }
}
1943
/// An unbounded range on both sides selects (approximately) every row.
#[test]
fn test_histogram_range_selectivity_full() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&values, 10).range_selectivity(None, None);

    assert!((1.0 - selectivity).abs() < 0.01);
}
1953
/// A lower bound of 50 on uniform values 0..100 selects strictly between 40%
/// and 60% of the rows.
#[test]
fn test_histogram_range_selectivity_half() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&values, 10).range_selectivity(Some(50.0), None);

    assert!(selectivity > 0.4);
    assert!(selectivity < 0.6);
}
1963
/// An upper bound of 25 on uniform values 0..100 selects strictly between
/// 20% and 30% of the rows.
#[test]
fn test_histogram_range_selectivity_quarter() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&values, 10).range_selectivity(None, Some(25.0));

    assert!(selectivity > 0.2);
    assert!(selectivity < 0.3);
}
1973
/// Equality on a single value among 100 distinct uniform values has a
/// selectivity strictly between 0.5% and 2%.
#[test]
fn test_histogram_equality_selectivity() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let selectivity = EquiDepthHistogram::build(&values, 10).equality_selectivity(50.0);

    assert!(selectivity > 0.005);
    assert!(selectivity < 0.02);
}
1983
/// A histogram built from no values has no buckets and no rows, and its range
/// selectivity falls back to roughly 0.33.
#[test]
fn test_histogram_empty() {
    let histogram = EquiDepthHistogram::build(&[], 10);

    assert_eq!(histogram.num_buckets(), 0);
    assert_eq!(histogram.total_rows(), 0);

    let fallback = histogram.range_selectivity(Some(0.0), Some(100.0));
    assert!((0.33 - fallback).abs() < 0.01);
}
1995
/// `overlap_fraction` on the bucket [10, 20) covers the whole bucket, each
/// half, and the disjoint ranges entirely below and entirely above it.
#[test]
fn test_histogram_bucket_overlap() {
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);

    // (lower, upper, expected fraction of the bucket covered)
    let cases = [
        (10.0, 20.0, 1.0), // exactly the bucket
        (10.0, 15.0, 0.5), // lower half
        (15.0, 20.0, 0.5), // upper half
        (0.0, 5.0, 0.0),   // entirely below
        (25.0, 30.0, 0.0), // entirely above
    ];

    for (lower, upper, expected) in cases {
        let fraction = bucket.overlap_fraction(Some(lower), Some(upper));
        assert!((fraction - expected).abs() < 0.01);
    }
}
2015
/// `ColumnStats::from_values` counts the 5 distinct values among 8 samples,
/// records min 10 and max 50, and builds a histogram.
#[test]
fn test_column_stats_from_values() {
    let samples = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let stats = ColumnStats::from_values(samples, 4);

    assert_eq!(stats.distinct_count, 5);
    assert!(matches!(stats.min_value, Some(min) if (min - 10.0).abs() < 0.01));
    assert!(matches!(stats.max_value, Some(max) if (max - 50.0).abs() < 0.01));
    assert!(stats.histogram.is_some());
}
2028
/// `with_histogram` attaches the supplied 10-bucket histogram to the column
/// statistics.
#[test]
fn test_column_stats_with_histogram_builder() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&values, 10));

    assert!(stats.histogram.is_some());
    assert_eq!(stats.histogram.as_ref().map(|h| h.num_buckets()), Some(10));
}
2041
/// With an equi-depth histogram over ages 18..80, `n.age > 50` on 1000
/// `Person` rows is estimated strictly between 300 and 600 rows.
#[test]
fn test_filter_with_histogram_stats() {
    let ages: Vec<f64> = (18..80).map(f64::from).collect();
    let age_stats = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(EquiDepthHistogram::build(&ages, 10));

    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats(
            "Person",
            TableStats::new(1000).with_column("age", age_stats),
        );
        est
    };

    let age_over_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: age_over_50,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let estimated = estimator.estimate(&filter);
    assert!(estimated > 300.0);
    assert!(estimated < 600.0);
}
2083
/// An equality filter on a histogram-backed column with 100 distinct values
/// over 1000 `Data` rows is estimated in the half-open range [1, 50).
#[test]
fn test_filter_equality_with_histogram() {
    let values: Vec<f64> = (0..100).map(f64::from).collect();
    let stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&values, 10));

    let estimator = {
        let mut est = CardinalityEstimator::new();
        est.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
        est
    };

    let value_is_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "d".to_string(),
            property: "value".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "d".to_string(),
        label: Some("Data".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: value_is_50,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let estimated = estimator.estimate(&filter);
    assert!(estimated >= 1.0 && estimated < 50.0);
}
2121
/// The histogram's minimum is the smallest input value (5.0), and a maximum
/// is reported.
#[test]
fn test_histogram_min_max() {
    let histogram = EquiDepthHistogram::build(&[5.0, 10.0, 15.0, 20.0, 25.0], 2);

    assert_eq!(histogram.min_value(), Some(5.0));
    assert!(histogram.max_value().is_some());
}
2131
/// `SelectivityConfig::new()` uses the documented default selectivities.
#[test]
fn test_selectivity_config_defaults() {
    let config = SelectivityConfig::new();

    // (configured value, expected default)
    let expectations = [
        (config.default, 0.1),
        (config.equality, 0.01),
        (config.inequality, 0.99),
        (config.range, 0.33),
        (config.string_ops, 0.1),
        (config.membership, 0.1),
        (config.is_null, 0.05),
        (config.is_not_null, 0.95),
        (config.distinct_fraction, 0.5),
    ];

    for (actual, expected) in expectations {
        assert!((actual - expected).abs() < f64::EPSILON);
    }
}
2147
/// An estimator built with a custom `SelectivityConfig` reports the
/// overridden `equality` and `range` values back.
#[test]
fn test_custom_selectivity_config() {
    let estimator = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        equality: 0.05,
        range: 0.25,
        ..SelectivityConfig::new()
    });

    let cfg = estimator.selectivity_config();
    assert!((cfg.equality - 0.05).abs() < f64::EPSILON);
    assert!((cfg.range - 0.25).abs() < f64::EPSILON);
}
2159
/// Raising the equality selectivity to 0.2 makes an equality filter over
/// 1000 rows estimate about 200 rows, more than the default configuration
/// gives.
#[test]
fn test_custom_selectivity_affects_estimation() {
    // Registers the same 1000-row `Person` table on any estimator.
    let with_people = |mut est: CardinalityEstimator| {
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };
    let default_est = with_people(CardinalityEstimator::new());
    let custom_est = with_people(CardinalityEstimator::with_selectivity_config(
        SelectivityConfig {
            equality: 0.2,
            ..SelectivityConfig::new()
        },
    ));

    let name_is_alix = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: name_is_alix,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let default_card = default_est.estimate(&filter);
    let custom_card = custom_est.estimate(&filter);

    assert!(custom_card > default_card);
    assert!((200.0 - custom_card).abs() < 1.0);
}
2198
/// With `range` selectivity set to 0.5 and no column statistics, `n.age > 30`
/// over 1000 rows is estimated at about 500 rows.
#[test]
fn test_custom_range_selectivity() {
    let estimator = {
        let mut est = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
            range: 0.5,
            ..SelectivityConfig::new()
        });
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(scan),
        pushdown_hint: None,
    });

    let estimated = estimator.estimate(&filter);
    assert!((500.0 - estimated).abs() < 1.0);
}
2229
/// With `distinct_fraction` set to 0.8, `DISTINCT` over 1000 rows is
/// estimated at about 800 rows.
#[test]
fn test_custom_distinct_fraction() {
    let estimator = {
        let mut est = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
            distinct_fraction: 0.8,
            ..SelectivityConfig::new()
        });
        est.add_table_stats("Person", TableStats::new(1000));
        est
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });

    let estimated = estimator.estimate(&distinct);
    assert!((800.0 - estimated).abs() < 1.0);
}
2252
/// Two entries, each off by much less than the 10x threshold, are both
/// recorded and do not request a replan.
#[test]
fn test_estimation_log_basic() {
    let mut log = EstimationLog::new(10.0);
    for (operator, estimated, actual) in [
        ("NodeScan(Person)", 1000.0, 1200.0),
        ("Filter(age > 30)", 100.0, 90.0),
    ] {
        log.record(operator, estimated, actual);
    }

    assert_eq!(log.entries().len(), 2);
    assert!(!log.should_replan());
}
2264
/// A 50x underestimate (100 estimated, 5000 actual) exceeds the 10x
/// threshold and requests a replan.
#[test]
fn test_estimation_log_triggers_replan() {
    let mut log = EstimationLog::new(10.0);
    log.record("NodeScan(Person)", 100.0, 5000.0);

    assert!(log.should_replan());
}
2272
/// A 10x overestimate (1000 estimated, 100 actual) exceeds a 5x threshold,
/// so overestimates trigger a replan as well as underestimates.
#[test]
fn test_estimation_log_overestimate_triggers_replan() {
    let mut log = EstimationLog::new(5.0);
    log.record("Filter", 1000.0, 100.0);

    assert!(log.should_replan());
}
2280
/// `error_ratio` is 2.0 for a 2x underestimate, 1.0 for an exact match, and
/// 1.0 when both the estimate and the actual count are zero.
#[test]
fn test_estimation_entry_error_ratio() {
    let ratio_of = |estimated, actual| {
        let entry = EstimationEntry {
            operator: "test".into(),
            estimated,
            actual,
        };
        entry.error_ratio()
    };

    // (estimated, actual, expected error ratio)
    for (estimated, actual, expected) in [
        (100.0, 200.0, 2.0),
        (100.0, 100.0, 1.0),
        (0.0, 0.0, 1.0),
    ] {
        assert!((ratio_of(estimated, actual) - expected).abs() < f64::EPSILON);
    }
}
2304
/// `max_error_ratio` reports the worst entry: 3x for the 100-vs-300 record.
#[test]
fn test_estimation_log_max_error_ratio() {
    let mut log = EstimationLog::new(10.0);
    for (operator, estimated, actual) in [
        ("A", 100.0, 300.0),
        ("B", 100.0, 50.0),
        ("C", 100.0, 100.0),
    ] {
        log.record(operator, estimated, actual);
    }

    assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
}
2314
/// `clear` drops every recorded entry and leaves the log with no replan
/// request.
#[test]
fn test_estimation_log_clear() {
    let mut log = EstimationLog::new(10.0);

    log.record("A", 100.0, 100.0);
    assert_eq!(log.entries().len(), 1);

    log.clear();
    assert!(log.entries().is_empty());
    assert!(!log.should_replan());
}
2325
/// A freshly created estimation log is empty and does not request a replan.
#[test]
fn test_create_estimation_log() {
    let fresh = CardinalityEstimator::create_estimation_log();

    assert!(fresh.entries().is_empty());
    assert!(!fresh.should_replan());
}
2332}