1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19 LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20 TextScanOp, TripleComponent, TripleScanOp, UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One histogram bucket: a value interval plus the number of rows and
/// distinct values that fall inside it.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Lower edge of the interval; treated as inclusive by [`HistogramBucket::contains`].
    pub lower_bound: f64,
    /// Upper edge of the interval; treated as exclusive by [`HistogramBucket::contains`].
    pub upper_bound: f64,
    /// Number of rows recorded for this bucket.
    pub frequency: u64,
    /// Number of distinct values recorded for this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and counts. No validation is performed.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        let Self {
            lower_bound,
            upper_bound,
            ..
        } = self;
        upper_bound - lower_bound
    }

    /// Returns `true` when `value` lies in the half-open interval
    /// `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Returns the fraction (0.0 to 1.0) of this bucket's width that is covered
    /// by the query range `[lower, upper]`.
    ///
    /// A `None` bound means "unbounded on that side". For a bucket whose width
    /// is zero or negative the result is 1.0 when the clipped range still
    /// spans the bucket's bounds and 0.0 otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the query range to the bucket's own bounds.
        let clipped_lo = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let clipped_hi = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Degenerate bucket: it is either fully covered or not at all.
            let fully_covered = clipped_lo <= self.lower_bound && clipped_hi >= self.upper_bound;
            return if fully_covered { 1.0 } else { 0.0 };
        }

        let covered = (clipped_hi - clipped_lo).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106 buckets: Vec<HistogramBucket>,
108 total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113 #[must_use]
115 pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116 let total_rows = buckets.iter().map(|b| b.frequency).sum();
117 Self {
118 buckets,
119 total_rows,
120 }
121 }
122
123 #[must_use]
132 pub fn build(values: &[f64], num_buckets: usize) -> Self {
133 if values.is_empty() || num_buckets == 0 {
134 return Self {
135 buckets: Vec::new(),
136 total_rows: 0,
137 };
138 }
139
140 let num_buckets = num_buckets.min(values.len());
141 let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142 let mut buckets = Vec::with_capacity(num_buckets);
143
144 let mut start_idx = 0;
145 while start_idx < values.len() {
146 let end_idx = (start_idx + rows_per_bucket).min(values.len());
147 let lower_bound = values[start_idx];
148 let upper_bound = if end_idx < values.len() {
149 values[end_idx]
150 } else {
151 values[end_idx - 1] + 1.0
153 };
154
155 let bucket_values = &values[start_idx..end_idx];
157 let distinct_count = count_distinct(bucket_values);
158
159 buckets.push(HistogramBucket::new(
160 lower_bound,
161 upper_bound,
162 (end_idx - start_idx) as u64,
163 distinct_count,
164 ));
165
166 start_idx = end_idx;
167 }
168
169 Self::new(buckets)
170 }
171
172 #[must_use]
174 pub fn num_buckets(&self) -> usize {
175 self.buckets.len()
176 }
177
178 #[must_use]
180 pub fn total_rows(&self) -> u64 {
181 self.total_rows
182 }
183
184 #[must_use]
186 pub fn buckets(&self) -> &[HistogramBucket] {
187 &self.buckets
188 }
189
190 #[must_use]
199 pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200 if self.buckets.is_empty() || self.total_rows == 0 {
201 return 0.33; }
203
204 let mut matching_rows = 0.0;
205
206 for bucket in &self.buckets {
207 let bucket_lower = bucket.lower_bound;
209 let bucket_upper = bucket.upper_bound;
210
211 if let Some(l) = lower
213 && bucket_upper <= l
214 {
215 continue;
216 }
217 if let Some(u) = upper
218 && bucket_lower >= u
219 {
220 continue;
221 }
222
223 let overlap = bucket.overlap_fraction(lower, upper);
225 matching_rows += overlap * bucket.frequency as f64;
226 }
227
228 (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229 }
230
231 #[must_use]
235 pub fn equality_selectivity(&self, value: f64) -> f64 {
236 if self.buckets.is_empty() || self.total_rows == 0 {
237 return 0.01; }
239
240 for bucket in &self.buckets {
242 if bucket.contains(value) {
243 if bucket.distinct_count > 0 {
245 return (bucket.frequency as f64
246 / bucket.distinct_count as f64
247 / self.total_rows as f64)
248 .min(1.0);
249 }
250 }
251 }
252
253 0.001
255 }
256
257 #[must_use]
259 pub fn min_value(&self) -> Option<f64> {
260 self.buckets.first().map(|b| b.lower_bound)
261 }
262
263 #[must_use]
265 pub fn max_value(&self) -> Option<f64> {
266 self.buckets.last().map(|b| b.upper_bound)
267 }
268}
269
270fn count_and_conjuncts(expr: &LogicalExpression) -> usize {
276 match expr {
277 LogicalExpression::Binary {
278 op: BinaryOp::And,
279 left,
280 right,
281 } => count_and_conjuncts(left) + count_and_conjuncts(right),
282 _ => 1,
283 }
284}
285
/// Counts runs of distinct values in a slice that is expected to be sorted.
///
/// A value starts a new run when it differs from the first value of the
/// current run by more than `f64::EPSILON` (absolute difference). Returns 0
/// for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let Some((&first, rest)) = sorted_values.split_first() else {
        return 0;
    };

    // State is (runs seen so far, value that opened the current run).
    let (runs, _anchor) = rest.iter().fold((1u64, first), |(runs, anchor), &value| {
        if (value - anchor).abs() > f64::EPSILON {
            (runs + 1, value)
        } else {
            (runs, anchor)
        }
    });

    runs
}
303
304#[derive(Debug, Clone)]
306pub struct TableStats {
307 pub row_count: u64,
309 pub columns: HashMap<String, ColumnStats>,
311}
312
313impl TableStats {
314 #[must_use]
316 pub fn new(row_count: u64) -> Self {
317 Self {
318 row_count,
319 columns: HashMap::new(),
320 }
321 }
322
323 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
325 self.columns.insert(name.to_string(), stats);
326 self
327 }
328}
329
330#[derive(Debug, Clone)]
332pub struct ColumnStats {
333 pub distinct_count: u64,
335 pub null_count: u64,
337 pub min_value: Option<f64>,
339 pub max_value: Option<f64>,
341 pub histogram: Option<EquiDepthHistogram>,
343}
344
345impl ColumnStats {
346 #[must_use]
348 pub fn new(distinct_count: u64) -> Self {
349 Self {
350 distinct_count,
351 null_count: 0,
352 min_value: None,
353 max_value: None,
354 histogram: None,
355 }
356 }
357
358 #[must_use]
360 pub fn with_nulls(mut self, null_count: u64) -> Self {
361 self.null_count = null_count;
362 self
363 }
364
365 #[must_use]
367 pub fn with_range(mut self, min: f64, max: f64) -> Self {
368 self.min_value = Some(min);
369 self.max_value = Some(max);
370 self
371 }
372
373 #[must_use]
375 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
376 self.histogram = Some(histogram);
377 self
378 }
379
380 #[must_use]
388 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
389 if values.is_empty() {
390 return Self::new(0);
391 }
392
393 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
395
396 let min = values.first().copied();
397 let max = values.last().copied();
398 let distinct_count = count_distinct(&values);
399 let histogram = EquiDepthHistogram::build(&values, num_buckets);
400
401 Self {
402 distinct_count,
403 null_count: 0,
404 min_value: min,
405 max_value: max,
406 histogram: Some(histogram),
407 }
408 }
409}
410
/// Fallback selectivities (fractions of rows expected to pass) used when no
/// column statistics are available for a predicate.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Selectivity for predicates with no more specific rule.
    pub default: f64,
    /// Selectivity for equality comparisons.
    pub equality: f64,
    /// Selectivity for not-equal comparisons.
    pub inequality: f64,
    /// Selectivity for `<`, `<=`, `>`, `>=` comparisons.
    pub range: f64,
    /// Selectivity for string operators (starts-with, ends-with, contains, like).
    pub string_ops: f64,
    /// Selectivity for `IN` membership tests.
    pub membership: f64,
    /// Selectivity for `IS NULL`.
    pub is_null: f64,
    /// Selectivity for `IS NOT NULL`.
    pub is_not_null: f64,
    /// Fraction of rows assumed to survive a DISTINCT.
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Creates a configuration with the built-in default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl Default for SelectivityConfig {
    /// Returns the built-in default values (same as [`SelectivityConfig::new`]).
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}
461
/// One estimated-versus-actual cardinality observation for an operator.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Name of the operator the observation belongs to.
    pub operator: String,
    /// Cardinality the estimator predicted.
    pub estimated: f64,
    /// Cardinality actually observed.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// When the estimate is (numerically) zero the ratio is 1.0 if the actual
    /// value is also zero and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() < f64::EPSILON {
            if self.actual.abs() < f64::EPSILON {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            self.actual / self.estimated
        }
    }
}

/// Collects [`EstimationEntry`] records and decides whether the estimates were
/// far enough off to justify re-planning.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded observations, in insertion order.
    entries: Vec<EstimationEntry>,
    /// An entry triggers a re-plan when its error ratio is above this value or
    /// below its reciprocal.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// Creates an empty log with a re-plan threshold of 10.0.
    ///
    /// A derived `Default` would set the threshold to 0.0, which makes
    /// `should_replan` report `true` even for a perfect estimate.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REPLAN_THRESHOLD)
    }
}

impl EstimationLog {
    /// Threshold used by [`Default`].
    const DEFAULT_REPLAN_THRESHOLD: f64 = 10.0;

    /// Creates an empty log with the given re-plan threshold.
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Appends one observation to the log.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// Read-only view of the recorded observations.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns `true` when any entry's error ratio is greater than the
    /// threshold or smaller than `1 / threshold`.
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Returns the largest error ratio in the log, with under-estimates and
    /// over-estimates put on the same scale (ratios below 1.0 are inverted).
    /// An empty log yields 1.0.
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                // Invert ratios below 1 so a 10x over-estimate and a 10x
                // under-estimate both report as 10.
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Removes all recorded observations; the threshold is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
562
/// Estimates how many rows each logical operator of a query plan produces.
pub struct CardinalityEstimator {
    /// Per-table (label) statistics, keyed by the name passed to `add_table_stats`.
    table_stats: HashMap<String, TableStats>,
    /// Row count assumed when no statistics are found for a scan.
    default_row_count: u64,
    /// Selectivity used for predicates with no more specific rule; set from
    /// `SelectivityConfig::default` by the constructors.
    default_selectivity: f64,
    /// Fanout factor used by `estimate_expand` to scale its input cardinality.
    avg_fanout: f64,
    /// Fallback selectivities for the individual predicate kinds.
    selectivity_config: SelectivityConfig,
    /// RDF statistics consulted by `estimate_triple_scan`, when present.
    rdf_statistics: Option<grafeo_core::statistics::RdfStatistics>,
}
578
579impl CardinalityEstimator {
580 #[must_use]
582 pub fn new() -> Self {
583 let config = SelectivityConfig::new();
584 Self {
585 table_stats: HashMap::new(),
586 default_row_count: 1000,
587 default_selectivity: config.default,
588 avg_fanout: 10.0,
589 selectivity_config: config,
590 rdf_statistics: None,
591 }
592 }
593
594 #[must_use]
596 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
597 Self {
598 table_stats: HashMap::new(),
599 default_row_count: 1000,
600 default_selectivity: config.default,
601 avg_fanout: 10.0,
602 selectivity_config: config,
603 rdf_statistics: None,
604 }
605 }
606
607 #[must_use]
609 pub fn selectivity_config(&self) -> &SelectivityConfig {
610 &self.selectivity_config
611 }
612
613 #[must_use]
615 pub fn create_estimation_log() -> EstimationLog {
616 EstimationLog::new(10.0)
617 }
618
619 #[must_use]
625 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
626 let mut estimator = Self::new();
627
628 if stats.total_nodes > 0 {
630 estimator.default_row_count = stats.total_nodes;
631 }
632
633 for (label, label_stats) in &stats.labels {
635 let mut table_stats = TableStats::new(label_stats.node_count);
636
637 for (prop, col_stats) in &label_stats.properties {
639 let optimizer_col =
640 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
641 table_stats = table_stats.with_column(prop, optimizer_col);
642 }
643
644 estimator.add_table_stats(label, table_stats);
645 }
646
647 if !stats.edge_types.is_empty() {
649 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
650 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
651 } else if stats.total_nodes > 0 {
652 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
653 }
654
655 if estimator.avg_fanout < 1.0 {
657 estimator.avg_fanout = 1.0;
658 }
659
660 estimator
661 }
662
663 #[must_use]
668 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
669 let mut estimator = Self::new();
670 if rdf_stats.total_triples > 0 {
671 estimator.default_row_count = rdf_stats.total_triples;
672 }
673 estimator.rdf_statistics = Some(rdf_stats);
674 estimator
675 }
676
677 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
679 self.table_stats.insert(name.to_string(), stats);
680 }
681
682 pub fn set_avg_fanout(&mut self, fanout: f64) {
684 self.avg_fanout = fanout;
685 }
686
687 #[must_use]
689 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
690 match op {
691 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
692 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
693 LogicalOperator::Project(project) => self.estimate_project(project),
694 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
695 LogicalOperator::Join(join) => self.estimate_join(join),
696 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
697 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
698 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
699 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
700 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
701 LogicalOperator::Return(ret) => self.estimate(&ret.input),
702 LogicalOperator::Empty => 0.0,
703 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
704 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
705 LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
706 LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
707 LogicalOperator::TripleScan(scan) => self.estimate_triple_scan(scan),
708 LogicalOperator::TextScan(scan) => self.estimate_text_scan(scan),
709 _ => self.default_row_count as f64,
710 }
711 }
712
713 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
715 if let Some(label) = &scan.label
716 && let Some(stats) = self.table_stats.get(label)
717 {
718 return stats.row_count as f64;
719 }
720 self.default_row_count as f64
722 }
723
724 fn estimate_triple_scan(&self, scan: &TripleScanOp) -> f64 {
730 let base = if let Some(ref input) = scan.input {
733 self.estimate(input)
734 } else {
735 1.0
736 };
737
738 let Some(rdf_stats) = &self.rdf_statistics else {
739 return if scan.input.is_some() {
740 base * self.default_row_count as f64
741 } else {
742 self.default_row_count as f64
743 };
744 };
745
746 let subject_bound = matches!(
747 scan.subject,
748 TripleComponent::Iri(_)
749 | TripleComponent::Literal(_)
750 | TripleComponent::LangLiteral { .. }
751 );
752 let object_bound = matches!(
753 scan.object,
754 TripleComponent::Iri(_)
755 | TripleComponent::Literal(_)
756 | TripleComponent::LangLiteral { .. }
757 );
758 let predicate_iri = match &scan.predicate {
759 TripleComponent::Iri(iri) => Some(iri.as_str()),
760 _ => None,
761 };
762
763 let pattern_card = rdf_stats.estimate_triple_pattern_cardinality(
764 subject_bound,
765 predicate_iri,
766 object_bound,
767 );
768
769 if scan.input.is_some() {
770 let selectivity = if rdf_stats.total_triples > 0 {
772 pattern_card / rdf_stats.total_triples as f64
773 } else {
774 1.0
775 };
776 (base * pattern_card * selectivity).max(1.0)
777 } else {
778 pattern_card.max(1.0)
779 }
780 }
781
782 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
784 let input_cardinality = self.estimate(&filter.input);
785 let selectivity = self.estimate_selectivity(&filter.predicate);
786 (input_cardinality * selectivity).max(1.0)
787 }
788
789 fn estimate_project(&self, project: &ProjectOp) -> f64 {
791 self.estimate(&project.input)
792 }
793
794 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
796 let input_cardinality = self.estimate(&expand.input);
797
798 let fanout = if !expand.edge_types.is_empty() {
800 self.avg_fanout * 0.5
802 } else {
803 self.avg_fanout
804 };
805
806 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
808 let min = expand.min_hops as f64;
809 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
810 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
812 } else {
813 fanout
814 };
815
816 (input_cardinality * path_multiplier).max(1.0)
817 }
818
819 fn estimate_join(&self, join: &JoinOp) -> f64 {
821 let left_card = self.estimate(&join.left);
822 let right_card = self.estimate(&join.right);
823
824 match join.join_type {
825 JoinType::Cross => left_card * right_card,
826 JoinType::Inner => {
827 let selectivity = if join.conditions.is_empty() {
829 1.0 } else {
831 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
834 let exp = join.conditions.len() as i32;
835 0.1_f64.powi(exp)
836 };
837 (left_card * right_card * selectivity).max(1.0)
838 }
839 JoinType::Left => {
840 let inner_card = self.estimate_join(&JoinOp {
842 left: join.left.clone(),
843 right: join.right.clone(),
844 join_type: JoinType::Inner,
845 conditions: join.conditions.clone(),
846 });
847 inner_card.max(left_card)
848 }
849 JoinType::Right => {
850 let inner_card = self.estimate_join(&JoinOp {
852 left: join.left.clone(),
853 right: join.right.clone(),
854 join_type: JoinType::Inner,
855 conditions: join.conditions.clone(),
856 });
857 inner_card.max(right_card)
858 }
859 JoinType::Full => {
860 let inner_card = self.estimate_join(&JoinOp {
862 left: join.left.clone(),
863 right: join.right.clone(),
864 join_type: JoinType::Inner,
865 conditions: join.conditions.clone(),
866 });
867 inner_card.max(left_card.max(right_card))
868 }
869 JoinType::Semi => {
870 (left_card * self.default_selectivity).max(1.0)
872 }
873 JoinType::Anti => {
874 (left_card * (1.0 - self.default_selectivity)).max(1.0)
876 }
877 }
878 }
879
880 fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
889 let left_card = self.estimate(&lj.left);
890 let right_card = self.estimate(&lj.right);
891
892 let condition_selectivity = if let Some(cond) = &lj.condition {
895 let n = count_and_conjuncts(cond);
896 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
898 let exp = n as i32;
899 self.default_selectivity.powi(exp)
900 } else {
901 self.default_selectivity
902 };
903
904 let inner_estimate = left_card * right_card * condition_selectivity;
906 inner_estimate.max(left_card).max(1.0)
907 }
908
909 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
911 let input_cardinality = self.estimate(&agg.input);
912
913 if agg.group_by.is_empty() {
914 1.0
916 } else {
917 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
921 let exp = agg.group_by.len() as i32;
922 let group_reduction = 10.0_f64.powi(exp);
923 (input_cardinality / group_reduction).max(1.0)
924 }
925 }
926
927 fn estimate_sort(&self, sort: &SortOp) -> f64 {
929 self.estimate(&sort.input)
930 }
931
932 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
934 let input_cardinality = self.estimate(&distinct.input);
935 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
936 }
937
938 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
940 let input_cardinality = self.estimate(&limit.input);
941 limit.count.estimate().min(input_cardinality)
942 }
943
944 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
946 let input_cardinality = self.estimate(&skip.input);
947 (input_cardinality - skip.count.estimate()).max(0.0)
948 }
949
950 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
955 if let Some(k) = scan.k {
956 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
958 0.7 } else {
960 1.0
961 };
962 (k as f64 * selectivity).max(1.0)
963 } else {
964 let base = scan
966 .label
967 .as_deref()
968 .and_then(|l| self.table_stats.get(l))
969 .map_or(self.default_row_count as f64, |s| s.row_count as f64);
970 (base * 0.2).max(1.0)
971 }
972 }
973
974 fn estimate_text_scan(&self, scan: &TextScanOp) -> f64 {
979 if let Some(k) = scan.k {
980 return k as f64;
982 }
983 if scan.threshold.is_some() {
984 let default_selectivity = 0.1;
986 let base = if let Some(stats) = self.table_stats.get(&scan.label) {
987 stats.row_count as f64
988 } else {
989 self.default_row_count as f64
990 };
991 return (base * default_selectivity).max(1.0);
992 }
993 100.0
995 }
996
997 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
1001 let input_cardinality = self.estimate(&join.input);
1002 let k = join.k as f64;
1003
1004 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
1006 0.7
1007 } else {
1008 1.0
1009 };
1010
1011 (input_cardinality * k * selectivity).max(1.0)
1012 }
1013
1014 fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
1019 if mwj.inputs.is_empty() {
1020 return 0.0;
1021 }
1022 let cardinalities: Vec<f64> = mwj
1023 .inputs
1024 .iter()
1025 .map(|input| self.estimate(input))
1026 .collect();
1027 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
1028 let n = cardinalities.len() as f64;
1029 (min_card.powf(n / 2.0)).max(1.0)
1031 }
1032
1033 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
1035 match expr {
1036 LogicalExpression::Binary { left, op, right } => {
1037 self.estimate_binary_selectivity(left, *op, right)
1038 }
1039 LogicalExpression::Unary { op, operand } => {
1040 self.estimate_unary_selectivity(*op, operand)
1041 }
1042 LogicalExpression::Literal(value) => {
1043 if let grafeo_common::types::Value::Bool(b) = value {
1045 if *b { 1.0 } else { 0.0 }
1046 } else {
1047 self.default_selectivity
1048 }
1049 }
1050 _ => self.default_selectivity,
1051 }
1052 }
1053
1054 fn estimate_binary_selectivity(
1056 &self,
1057 left: &LogicalExpression,
1058 op: BinaryOp,
1059 right: &LogicalExpression,
1060 ) -> f64 {
1061 match op {
1062 BinaryOp::Eq => {
1064 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
1065 return selectivity;
1066 }
1067 self.selectivity_config.equality
1068 }
1069 BinaryOp::Ne => self.selectivity_config.inequality,
1071 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
1073 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
1074 return selectivity;
1075 }
1076 self.selectivity_config.range
1077 }
1078 BinaryOp::And => {
1080 let left_sel = self.estimate_selectivity(left);
1081 let right_sel = self.estimate_selectivity(right);
1082 left_sel * right_sel
1084 }
1085 BinaryOp::Or => {
1086 let left_sel = self.estimate_selectivity(left);
1087 let right_sel = self.estimate_selectivity(right);
1088 (left_sel + right_sel - left_sel * right_sel).min(1.0)
1091 }
1092 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
1094 self.selectivity_config.string_ops
1095 }
1096 BinaryOp::In => self.selectivity_config.membership,
1098 _ => self.default_selectivity,
1100 }
1101 }
1102
1103 fn try_equality_selectivity(
1105 &self,
1106 left: &LogicalExpression,
1107 right: &LogicalExpression,
1108 ) -> Option<f64> {
1109 let (label, column, value) = self.extract_column_and_value(left, right)?;
1111
1112 let stats = self.get_column_stats(&label, &column)?;
1114
1115 if let Some(ref histogram) = stats.histogram {
1117 return Some(histogram.equality_selectivity(value));
1118 }
1119
1120 if stats.distinct_count > 0 {
1122 return Some(1.0 / stats.distinct_count as f64);
1123 }
1124
1125 None
1126 }
1127
1128 fn try_range_selectivity(
1130 &self,
1131 left: &LogicalExpression,
1132 op: BinaryOp,
1133 right: &LogicalExpression,
1134 ) -> Option<f64> {
1135 let (label, column, value) = self.extract_column_and_value(left, right)?;
1137
1138 let stats = self.get_column_stats(&label, &column)?;
1140
1141 let (lower, upper) = match op {
1143 BinaryOp::Lt => (None, Some(value)),
1144 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1145 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1146 BinaryOp::Ge => (Some(value), None),
1147 _ => return None,
1148 };
1149
1150 if let Some(ref histogram) = stats.histogram {
1152 return Some(histogram.range_selectivity(lower, upper));
1153 }
1154
1155 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1157 let range = max - min;
1158 if range <= 0.0 {
1159 return Some(1.0);
1160 }
1161
1162 let effective_lower = lower.unwrap_or(min).max(min);
1163 let effective_upper = upper.unwrap_or(max).min(max);
1164 let overlap = (effective_upper - effective_lower).max(0.0);
1165 return Some((overlap / range).clamp(0.0, 1.0));
1166 }
1167
1168 None
1169 }
1170
1171 fn extract_column_and_value(
1176 &self,
1177 left: &LogicalExpression,
1178 right: &LogicalExpression,
1179 ) -> Option<(String, String, f64)> {
1180 if let Some(result) = self.try_extract_property_literal(left, right) {
1182 return Some(result);
1183 }
1184
1185 self.try_extract_property_literal(right, left)
1187 }
1188
1189 fn try_extract_property_literal(
1191 &self,
1192 property_expr: &LogicalExpression,
1193 literal_expr: &LogicalExpression,
1194 ) -> Option<(String, String, f64)> {
1195 let (variable, property) = match property_expr {
1197 LogicalExpression::Property { variable, property } => {
1198 (variable.clone(), property.clone())
1199 }
1200 _ => return None,
1201 };
1202
1203 let value = match literal_expr {
1205 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1206 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1207 _ => return None,
1208 };
1209
1210 for label in self.table_stats.keys() {
1214 if let Some(stats) = self.table_stats.get(label)
1215 && stats.columns.contains_key(&property)
1216 {
1217 return Some((label.clone(), property, value));
1218 }
1219 }
1220
1221 Some((variable, property, value))
1223 }
1224
1225 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1227 match op {
1228 UnaryOp::Not => 1.0 - self.default_selectivity,
1229 UnaryOp::IsNull => self.selectivity_config.is_null,
1230 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1231 UnaryOp::Neg => 1.0, }
1233 }
1234
1235 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1237 self.table_stats.get(label)?.columns.get(column)
1238 }
1239}
1240
1241impl Default for CardinalityEstimator {
1242 fn default() -> Self {
1243 Self::new()
1244 }
1245}
1246
1247#[cfg(test)]
1248mod tests {
1249 use super::*;
1250 use crate::query::plan::{
1251 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1252 ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1253 };
1254 use grafeo_common::types::Value;
1255
1256 #[test]
1257 fn test_node_scan_with_stats() {
1258 let mut estimator = CardinalityEstimator::new();
1259 estimator.add_table_stats("Person", TableStats::new(5000));
1260
1261 let scan = LogicalOperator::NodeScan(NodeScanOp {
1262 variable: "n".to_string(),
1263 label: Some("Person".to_string()),
1264 input: None,
1265 });
1266
1267 let cardinality = estimator.estimate(&scan);
1268 assert!((cardinality - 5000.0).abs() < 0.001);
1269 }
1270
1271 #[test]
1272 fn test_filter_reduces_cardinality() {
1273 let mut estimator = CardinalityEstimator::new();
1274 estimator.add_table_stats("Person", TableStats::new(1000));
1275
1276 let filter = LogicalOperator::Filter(FilterOp {
1277 predicate: LogicalExpression::Binary {
1278 left: Box::new(LogicalExpression::Property {
1279 variable: "n".to_string(),
1280 property: "age".to_string(),
1281 }),
1282 op: BinaryOp::Eq,
1283 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1284 },
1285 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286 variable: "n".to_string(),
1287 label: Some("Person".to_string()),
1288 input: None,
1289 })),
1290 pushdown_hint: None,
1291 });
1292
1293 let cardinality = estimator.estimate(&filter);
1294 assert!(cardinality < 1000.0);
1296 assert!(cardinality >= 1.0);
1297 }
1298
1299 #[test]
1300 fn test_join_cardinality() {
1301 let mut estimator = CardinalityEstimator::new();
1302 estimator.add_table_stats("Person", TableStats::new(1000));
1303 estimator.add_table_stats("Company", TableStats::new(100));
1304
1305 let join = LogicalOperator::Join(JoinOp {
1306 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1307 variable: "p".to_string(),
1308 label: Some("Person".to_string()),
1309 input: None,
1310 })),
1311 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1312 variable: "c".to_string(),
1313 label: Some("Company".to_string()),
1314 input: None,
1315 })),
1316 join_type: JoinType::Inner,
1317 conditions: vec![JoinCondition {
1318 left: LogicalExpression::Property {
1319 variable: "p".to_string(),
1320 property: "company_id".to_string(),
1321 },
1322 right: LogicalExpression::Property {
1323 variable: "c".to_string(),
1324 property: "id".to_string(),
1325 },
1326 }],
1327 });
1328
1329 let cardinality = estimator.estimate(&join);
1330 assert!(cardinality < 1000.0 * 100.0);
1332 }
1333
1334 #[test]
1335 fn test_limit_caps_cardinality() {
1336 let mut estimator = CardinalityEstimator::new();
1337 estimator.add_table_stats("Person", TableStats::new(1000));
1338
1339 let limit = LogicalOperator::Limit(LimitOp {
1340 count: 10.into(),
1341 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1342 variable: "n".to_string(),
1343 label: Some("Person".to_string()),
1344 input: None,
1345 })),
1346 });
1347
1348 let cardinality = estimator.estimate(&limit);
1349 assert!((cardinality - 10.0).abs() < 0.001);
1350 }
1351
1352 #[test]
1353 fn test_aggregate_reduces_cardinality() {
1354 let mut estimator = CardinalityEstimator::new();
1355 estimator.add_table_stats("Person", TableStats::new(1000));
1356
1357 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1359 group_by: vec![],
1360 aggregates: vec![],
1361 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362 variable: "n".to_string(),
1363 label: Some("Person".to_string()),
1364 input: None,
1365 })),
1366 having: None,
1367 });
1368
1369 let cardinality = estimator.estimate(&global_agg);
1370 assert!((cardinality - 1.0).abs() < 0.001);
1371
1372 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1374 group_by: vec![LogicalExpression::Property {
1375 variable: "n".to_string(),
1376 property: "city".to_string(),
1377 }],
1378 aggregates: vec![],
1379 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1380 variable: "n".to_string(),
1381 label: Some("Person".to_string()),
1382 input: None,
1383 })),
1384 having: None,
1385 });
1386
1387 let cardinality = estimator.estimate(&group_agg);
1388 assert!(cardinality < 1000.0);
1390 }
1391
1392 #[test]
1393 fn test_node_scan_without_stats() {
1394 let estimator = CardinalityEstimator::new();
1395
1396 let scan = LogicalOperator::NodeScan(NodeScanOp {
1397 variable: "n".to_string(),
1398 label: Some("Unknown".to_string()),
1399 input: None,
1400 });
1401
1402 let cardinality = estimator.estimate(&scan);
1403 assert!((cardinality - 1000.0).abs() < 0.001);
1405 }
1406
1407 #[test]
1408 fn test_node_scan_no_label() {
1409 let estimator = CardinalityEstimator::new();
1410
1411 let scan = LogicalOperator::NodeScan(NodeScanOp {
1412 variable: "n".to_string(),
1413 label: None,
1414 input: None,
1415 });
1416
1417 let cardinality = estimator.estimate(&scan);
1418 assert!((cardinality - 1000.0).abs() < 0.001);
1420 }
1421
1422 #[test]
1423 fn test_filter_inequality_selectivity() {
1424 let mut estimator = CardinalityEstimator::new();
1425 estimator.add_table_stats("Person", TableStats::new(1000));
1426
1427 let filter = LogicalOperator::Filter(FilterOp {
1428 predicate: LogicalExpression::Binary {
1429 left: Box::new(LogicalExpression::Property {
1430 variable: "n".to_string(),
1431 property: "age".to_string(),
1432 }),
1433 op: BinaryOp::Ne,
1434 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1435 },
1436 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437 variable: "n".to_string(),
1438 label: Some("Person".to_string()),
1439 input: None,
1440 })),
1441 pushdown_hint: None,
1442 });
1443
1444 let cardinality = estimator.estimate(&filter);
1445 assert!(cardinality > 900.0);
1447 }
1448
#[test]
fn test_filter_range_selectivity() {
    // `age > 30` with no column statistics registered: the estimate must land
    // strictly between 100 and 500 of the 1000 Person rows.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let age = LogicalExpression::Property {
        variable: "n".to_owned(),
        property: "age".to_owned(),
    };
    let thirty = LogicalExpression::Literal(Value::Int64(30));
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(age),
            op: BinaryOp::Gt,
            right: Box::new(thirty),
        },
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_owned(),
            label: Some("Person".to_owned()),
            input: None,
        })),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    assert!(rows < 500.0);
    assert!(rows > 100.0);
}
1476
#[test]
fn test_filter_and_selectivity() {
    // Conjunction of two equality predicates (`city = "NYC"` AND `age = 30`).
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let city_is_nyc = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("city"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
    };
    let age_is_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let both = LogicalExpression::Binary {
        left: Box::new(city_is_nyc),
        op: BinaryOp::And,
        right: Box::new(age_is_30),
    };

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: both,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    // The estimate must be well below the input but never under one row.
    let rows = est.estimate(&plan);
    assert!(rows < 100.0);
    assert!(rows >= 1.0);
}
1518
#[test]
fn test_filter_or_selectivity() {
    // Disjunction of two equality predicates on the same property
    // (`city = "NYC"` OR `city = "LA"`).
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    // Builds `n.city = <name>`.
    let city_equals = |name: &'static str| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "city".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String(name.into()))),
    };

    let either_city = LogicalExpression::Binary {
        left: Box::new(city_equals("NYC")),
        op: BinaryOp::Or,
        right: Box::new(city_equals("LA")),
    };
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: either_city,
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        })),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    assert!(rows < 100.0);
    assert!(rows >= 1.0);
}
1560
#[test]
fn test_filter_literal_true() {
    // A constant `true` predicate must leave the 1000-row input estimate untouched.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let always_true = LogicalExpression::Literal(Value::Bool(true));
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: always_true,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 1000.0).abs();
    assert!(deviation < 0.001);
}
1580
#[test]
fn test_filter_literal_false() {
    // A constant `false` predicate is expected to yield an estimate of 1.0 row.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let always_false = LogicalExpression::Literal(Value::Bool(false));
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: always_false,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 1.0).abs();
    assert!(deviation < 0.001);
}
1600
#[test]
fn test_unary_not_selectivity() {
    // `NOT true` must reduce the estimate below the 1000-row input.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let not_true = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_owned(),
        label: Some("Person".to_owned()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: not_true,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    assert!(rows < 1000.0);
}
1623
#[test]
fn test_unary_is_null_selectivity() {
    // An IS NULL predicate must keep fewer than 100 of the 1000 input rows.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let x_is_null = LogicalExpression::Unary {
        op: UnaryOp::IsNull,
        operand: Box::new(LogicalExpression::Variable("x".to_owned())),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_owned(),
        label: Some("Person".to_owned()),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: x_is_null,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    assert!(rows < 100.0);
}
1646
#[test]
fn test_expand_cardinality() {
    // A single-hop outgoing expand with no edge-type restriction must produce
    // more rows than its 100-row input.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));

    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let rows = est.estimate(&plan);
    assert!(rows > 100.0);
}
1673
#[test]
fn test_expand_with_edge_type_filter() {
    // Same as the plain single-hop expand, but restricted to "KNOWS" edges;
    // the estimate must still exceed the 100-row input.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));

    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec![String::from("KNOWS")],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let rows = est.estimate(&plan);
    assert!(rows > 100.0);
}
1700
#[test]
fn test_expand_variable_length() {
    // A 1..=3 hop expand over 100 input rows must be estimated above 500 rows.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));

    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(3),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let rows = est.estimate(&plan);
    assert!(rows > 500.0);
}
1727
#[test]
fn test_join_cross_product() {
    // Cross join of 100 Person rows with 50 Company rows: expected 100 * 50 = 5000.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.add_table_stats("Company", TableStats::new(50));

    // Builds a labeled node scan bound to `variable`.
    let scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };

    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 5000.0).abs();
    assert!(deviation < 0.001);
}
1753
#[test]
fn test_join_left_outer() {
    // A left join must never be estimated below its left input (1000 rows).
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(10));

    // Builds a labeled node scan bound to `variable`.
    let scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };

    let on_p_equals_c = JoinCondition {
        left: LogicalExpression::Variable("p".to_string()),
        right: LogicalExpression::Variable("c".to_string()),
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Left,
        conditions: vec![on_p_equals_c],
    });

    let rows = est.estimate(&plan);
    assert!(rows >= 1000.0);
}
1782
#[test]
fn test_join_semi() {
    // A semi join must not be estimated above its left input (1000 rows).
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    // Builds a labeled node scan bound to `variable`.
    let scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };

    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    let rows = est.estimate(&plan);
    assert!(rows <= 1000.0);
}
1808
#[test]
fn test_join_anti() {
    // An anti join must stay within [1, left input rows] = [1, 1000].
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    // Builds a labeled node scan bound to `variable`.
    let scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };

    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let rows = est.estimate(&plan);
    assert!(rows <= 1000.0);
    assert!(rows >= 1.0);
}
1835
#[test]
fn test_project_preserves_cardinality() {
    // Projection must pass the 1000-row input estimate through unchanged.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let project_n = Projection {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Project(ProjectOp {
        projections: vec![project_n],
        input: Box::new(person_scan),
        pass_through_input: false,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 1000.0).abs();
    assert!(deviation < 0.001);
}
1857
#[test]
fn test_sort_preserves_cardinality() {
    // Sorting must pass the 1000-row input estimate through unchanged.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let by_n_ascending = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
        nulls: None,
    };
    let plan = LogicalOperator::Sort(SortOp {
        keys: vec![by_n_ascending],
        input: Box::new(person_scan),
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 1000.0).abs();
    assert!(deviation < 0.001);
}
1879
#[test]
fn test_distinct_reduces_cardinality() {
    // With a default-configured estimator, DISTINCT over 1000 rows is pinned at 500.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 500.0).abs();
    assert!(deviation < 0.001);
}
1898
#[test]
fn test_skip_reduces_cardinality() {
    // Skipping 100 of 1000 rows is expected to leave 900.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Skip(SkipOp {
        count: 100.into(),
        input: Box::new(person_scan),
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 900.0).abs();
    assert!(deviation < 0.001);
}
1916
#[test]
fn test_return_preserves_cardinality() {
    // A non-distinct RETURN must pass the 1000-row input estimate through unchanged.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let return_n = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![return_n],
        distinct: false,
        input: Box::new(person_scan),
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 1000.0).abs();
    assert!(deviation < 0.001);
}
1938
#[test]
fn test_empty_cardinality() {
    // The `Empty` operator must be estimated at (approximately) zero rows.
    let est = CardinalityEstimator::new();
    let empty_plan = LogicalOperator::Empty;
    let rows = est.estimate(&empty_plan);
    assert!(rows.abs() < 0.001);
}
1945
#[test]
fn test_table_stats_with_column() {
    // The builder chain must store every value it was given.
    let age_column = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let table = TableStats::new(1000).with_column("age", age_column);

    assert_eq!(table.row_count, 1000);

    let stored = table.columns.get("age").unwrap();
    assert_eq!(stored.distinct_count, 50);
    assert_eq!(stored.null_count, 10);

    let min = stored.min_value.unwrap();
    let max = stored.max_value.unwrap();
    assert!((min - 0.0).abs() < 0.001);
    assert!((max - 100.0).abs() < 0.001);
}
1960
#[test]
fn test_estimator_default() {
    // An estimator built via `Default` must give the same 1000-row estimate
    // for an unlabeled scan that the other scan tests pin.
    let est = CardinalityEstimator::default();
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });

    let rows = est.estimate(&unlabeled_scan);
    let deviation = (rows - 1000.0).abs();
    assert!(deviation < 0.001);
}
1972
#[test]
fn test_set_avg_fanout() {
    // With the average fanout set to 5.0, a single-hop expand over 100 rows
    // is expected to yield 100 * 5 = 500 rows.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.set_avg_fanout(5.0);

    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 500.0).abs();
    assert!(deviation < 0.001);
}
2000
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    // Aggregations with one and with two grouping keys must both be estimated
    // within [1, 10000) for a 10000-row input.
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(10000));

    // Builds the Person scan shared by both aggregations.
    let person_scan = || {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        })
    };
    // Builds the grouping key `n.<property>`.
    let key = |property: &str| LogicalExpression::Property {
        variable: "n".to_string(),
        property: property.to_string(),
    };

    let by_city = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![key("city")],
        aggregates: Vec::new(),
        input: Box::new(person_scan()),
        having: None,
    });
    let by_city_and_country = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![key("city"), key("country")],
        aggregates: Vec::new(),
        input: Box::new(person_scan()),
        having: None,
    });

    let single_rows = est.estimate(&by_city);
    let multi_rows = est.estimate(&by_city_and_country);

    assert!(single_rows < 10000.0);
    assert!(multi_rows < 10000.0);
    assert!(single_rows >= 1.0);
    assert!(multi_rows >= 1.0);
}
2053
#[test]
fn test_histogram_build_uniform() {
    // 100 evenly spaced values split into 10 buckets: every bucket should
    // hold roughly 10 rows (9..=11 tolerated).
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    assert_eq!(hist.num_buckets(), 10);
    assert_eq!(hist.total_rows(), 100);

    for bucket in hist.buckets() {
        assert!((9..=11).contains(&bucket.frequency));
    }
}
2070
#[test]
fn test_histogram_build_skewed() {
    // 80 values in 0..80 plus 20 values starting at 1000: despite the gap,
    // each of the 5 buckets should hold roughly 20 rows (18..=22 tolerated).
    let mut samples: Vec<f64> = (0..80_i32).map(f64::from).collect();
    samples.extend((0..20_i32).map(|offset| 1000.0 + f64::from(offset)));
    let hist = EquiDepthHistogram::build(&samples, 5);

    assert_eq!(hist.num_buckets(), 5);
    assert_eq!(hist.total_rows(), 100);

    for bucket in hist.buckets() {
        assert!((18..=22).contains(&bucket.frequency));
    }
}
2086
#[test]
fn test_histogram_range_selectivity_full() {
    // An unbounded range must select (approximately) everything.
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.range_selectivity(None, None);
    let deviation = (sel - 1.0).abs();
    assert!(deviation < 0.01);
}
2096
2097 #[test]
2098 fn test_histogram_range_selectivity_half() {
2099 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2100 let histogram = EquiDepthHistogram::build(&values, 10);
2101
2102 let selectivity = histogram.range_selectivity(Some(50.0), None);
2104 assert!(selectivity > 0.4 && selectivity < 0.6);
2105 }
2106
#[test]
fn test_histogram_range_selectivity_quarter() {
    // Values 0..100, range `..25`: roughly a quarter of the rows should match.
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.range_selectivity(None, Some(25.0));
    assert!(sel > 0.2);
    assert!(sel < 0.3);
}
2116
#[test]
fn test_histogram_equality_selectivity() {
    // 100 distinct values: an equality match should select about 1% of rows
    // (anything strictly between 0.5% and 2% is accepted).
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.equality_selectivity(50.0);
    assert!(sel > 0.005);
    assert!(sel < 0.02);
}
2126
#[test]
fn test_histogram_empty() {
    // Building from no values yields a histogram with no buckets and no rows.
    let no_values: [f64; 0] = [];
    let hist = EquiDepthHistogram::build(&no_values, 10);

    assert_eq!(hist.num_buckets(), 0);
    assert_eq!(hist.total_rows(), 0);

    // With nothing to consult, the range selectivity is pinned at 0.33.
    let sel = hist.range_selectivity(Some(0.0), Some(100.0));
    let deviation = (sel - 0.33).abs();
    assert!(deviation < 0.01);
}
2138
#[test]
fn test_histogram_bucket_overlap() {
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
    // Fraction of the [10, 20) bucket covered by the closed query range [lo, hi].
    let covered = |lo: f64, hi: f64| bucket.overlap_fraction(Some(lo), Some(hi));

    // Query range equal to the bucket: full overlap.
    assert!((covered(10.0, 20.0) - 1.0).abs() < 0.01);

    // Lower half of the bucket.
    assert!((covered(10.0, 15.0) - 0.5).abs() < 0.01);

    // Upper half of the bucket.
    assert!((covered(15.0, 20.0) - 0.5).abs() < 0.01);

    // Query range entirely below the bucket.
    assert!(covered(0.0, 5.0).abs() < 0.01);

    // Query range entirely above the bucket.
    assert!(covered(25.0, 30.0).abs() < 0.01);
}
2158
#[test]
fn test_column_stats_from_values() {
    // Eight samples, five of them distinct (10, 20, 30, 40, 50).
    let samples = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let derived = ColumnStats::from_values(samples, 4);

    assert_eq!(derived.distinct_count, 5);

    assert!(derived.min_value.is_some());
    let min = derived.min_value.unwrap();
    assert!((min - 10.0).abs() < 0.01);

    assert!(derived.max_value.is_some());
    let max = derived.max_value.unwrap();
    assert!((max - 50.0).abs() < 0.01);

    assert!(derived.histogram.is_some());
}
2171
#[test]
fn test_column_stats_with_histogram_builder() {
    // A histogram attached through the builder must be stored as-is.
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let column = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(hist);

    assert!(column.histogram.is_some());
    let stored = column.histogram.as_ref().unwrap();
    assert_eq!(stored.num_buckets(), 10);
}
2184
#[test]
fn test_filter_with_histogram_stats() {
    let mut est = CardinalityEstimator::new();

    // Ages 18..80 (62 distinct values) summarised by a 10-bucket histogram.
    let ages: Vec<f64> = (18_i32..80).map(f64::from).collect();
    let age_column = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(EquiDepthHistogram::build(&ages, 10));
    let person_stats = TableStats::new(1000).with_column("age", age_column);
    est.add_table_stats("Person", person_stats);

    // `n.age > 50` over the Person scan.
    let older_than_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: older_than_50,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    // The estimate must fall strictly between 300 and 600 of the 1000 rows.
    let rows = est.estimate(&plan);
    assert!(rows > 300.0);
    assert!(rows < 600.0);
}
2226
#[test]
fn test_filter_equality_with_histogram() {
    let mut est = CardinalityEstimator::new();

    // 100 distinct values 0..100 summarised by a 10-bucket histogram.
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let value_column = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&samples, 10));
    est.add_table_stats(
        "Data",
        TableStats::new(1000).with_column("value", value_column),
    );

    // `d.value = 50` over the Data scan.
    let value_is_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("d"),
            property: String::from("value"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let data_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("d"),
        label: Some(String::from("Data")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: value_is_50,
        input: Box::new(data_scan),
        pushdown_hint: None,
    });

    // The estimate must lie in the half-open interval [1, 50).
    let rows = est.estimate(&plan);
    assert!(rows >= 1.0 && rows < 50.0);
}
2264
#[test]
fn test_histogram_min_max() {
    // The minimum must equal the smallest sample; the maximum only has to exist.
    let samples: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
    let hist = EquiDepthHistogram::build(&samples, 2);

    let smallest = hist.min_value();
    assert_eq!(smallest, Some(5.0));

    let largest = hist.max_value();
    assert!(largest.is_some());
}
2274
#[test]
fn test_selectivity_config_defaults() {
    // Pins every default produced by `SelectivityConfig::new()`.
    let config = SelectivityConfig::new();

    let expectations = [
        (config.default, 0.1),
        (config.equality, 0.01),
        (config.inequality, 0.99),
        (config.range, 0.33),
        (config.string_ops, 0.1),
        (config.membership, 0.1),
        (config.is_null, 0.05),
        (config.is_not_null, 0.95),
        (config.distinct_fraction, 0.5),
    ];

    for (actual, expected) in expectations {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected default {expected}, got {actual}"
        );
    }
}
2290
#[test]
fn test_custom_selectivity_config() {
    // Overridden fields must be readable back through the estimator.
    let overrides = SelectivityConfig {
        equality: 0.05,
        range: 0.25,
        ..SelectivityConfig::new()
    };
    let est = CardinalityEstimator::with_selectivity_config(overrides);

    let stored = est.selectivity_config();
    assert!((stored.equality - 0.05).abs() < f64::EPSILON);

    let stored = est.selectivity_config();
    assert!((stored.range - 0.25).abs() < f64::EPSILON);
}
2302
#[test]
fn test_custom_selectivity_affects_estimation() {
    // The same `n.name = "Alix"` filter is estimated by a default estimator
    // and by one whose equality selectivity is raised to 0.2.
    let name_is_alix = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("name"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: name_is_alix,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let mut baseline = CardinalityEstimator::new();
    baseline.add_table_stats("Person", TableStats::new(1000));
    let baseline_rows = baseline.estimate(&plan);

    let mut tuned = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        equality: 0.2,
        ..SelectivityConfig::new()
    });
    tuned.add_table_stats("Person", TableStats::new(1000));
    let tuned_rows = tuned.estimate(&plan);

    // 1000 rows * 0.2 = 200, which must also exceed the baseline estimate.
    assert!(tuned_rows > baseline_rows);
    assert!((tuned_rows - 200.0).abs() < 1.0);
}
2341
#[test]
fn test_custom_range_selectivity() {
    // With the range selectivity overridden to 0.5, `age > 30` over 1000 rows
    // is expected to yield 500.
    let mut est = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        range: 0.5,
        ..SelectivityConfig::new()
    });
    est.add_table_stats("Person", TableStats::new(1000));

    let older_than_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: older_than_30,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 500.0).abs();
    assert!(deviation < 1.0);
}
2372
#[test]
fn test_custom_distinct_fraction() {
    // With the distinct fraction overridden to 0.8, DISTINCT over 1000 rows
    // is expected to yield 800.
    let mut est = CardinalityEstimator::with_selectivity_config(SelectivityConfig {
        distinct_fraction: 0.8,
        ..SelectivityConfig::new()
    });
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    let rows = est.estimate(&plan);
    let deviation = (rows - 800.0).abs();
    assert!(deviation < 1.0);
}
2395
#[test]
fn test_estimation_log_basic() {
    // Two mildly inaccurate estimates are recorded; with the log constructed
    // with 10.0 they must not request a replan.
    let mut log = EstimationLog::new(10.0);
    log.record("NodeScan(Person)", 1000.0, 1200.0);
    log.record("Filter(age > 30)", 100.0, 90.0);

    let recorded = log.entries().len();
    assert_eq!(recorded, 2);

    let wants_replan = log.should_replan();
    assert!(!wants_replan);
}
2407
#[test]
fn test_estimation_log_triggers_replan() {
    // Estimated 100 rows but saw 5000 (50x off): with the log constructed
    // with 10.0 this must request a replan.
    let mut log = EstimationLog::new(10.0);
    log.record("NodeScan(Person)", 100.0, 5000.0);

    let wants_replan = log.should_replan();
    assert!(wants_replan);
}
2415
#[test]
fn test_estimation_log_overestimate_triggers_replan() {
    // Estimated 1000 rows but saw only 100 (10x too high): with the log
    // constructed with 5.0 this must request a replan as well.
    let mut log = EstimationLog::new(5.0);
    log.record("Filter", 1000.0, 100.0);

    let wants_replan = log.should_replan();
    assert!(wants_replan);
}
2423
2424 #[test]
2425 fn test_estimation_entry_error_ratio() {
2426 let entry = EstimationEntry {
2427 operator: "test".into(),
2428 estimated: 100.0,
2429 actual: 200.0,
2430 };
2431 assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2432
2433 let perfect = EstimationEntry {
2434 operator: "test".into(),
2435 estimated: 100.0,
2436 actual: 100.0,
2437 };
2438 assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2439
2440 let zero_est = EstimationEntry {
2441 operator: "test".into(),
2442 estimated: 0.0,
2443 actual: 0.0,
2444 };
2445 assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2446 }
2447
#[test]
fn test_estimation_log_max_error_ratio() {
    // Three entries with differing accuracy; the reported maximum must be 3.0,
    // which matches the 100 -> 300 entry.
    let mut log = EstimationLog::new(10.0);
    log.record("A", 100.0, 300.0);
    log.record("B", 100.0, 50.0);
    log.record("C", 100.0, 100.0);

    let worst = log.max_error_ratio();
    assert!((worst - 3.0).abs() < f64::EPSILON);
}
2457
#[test]
fn test_estimation_log_clear() {
    // `clear` must drop all recorded entries and leave no replan request.
    let mut log = EstimationLog::new(10.0);
    log.record("A", 100.0, 100.0);
    let recorded_before = log.entries().len();
    assert_eq!(recorded_before, 1);

    log.clear();

    assert!(log.entries().is_empty());
    let wants_replan = log.should_replan();
    assert!(!wants_replan);
}
2468
#[test]
fn test_create_estimation_log() {
    // A freshly created log has no entries and does not request a replan.
    let fresh = CardinalityEstimator::create_estimation_log();

    let is_empty = fresh.entries().is_empty();
    assert!(is_empty);

    let wants_replan = fresh.should_replan();
    assert!(!wants_replan);
}
2475
#[test]
fn test_equality_selectivity_empty_histogram() {
    // A histogram without buckets is expected to report an equality
    // selectivity of 0.01.
    let hist = EquiDepthHistogram::new(vec![]);
    let sel = hist.equality_selectivity(5.0);

    // Compare with a tolerance instead of `assert_eq!`: exact float equality
    // is brittle and inconsistent with the other float assertions in this
    // module.
    assert!(
        (sel - 0.01).abs() < f64::EPSILON,
        "expected selectivity 0.01 for an empty histogram, got {sel}"
    );
}
2482
#[test]
fn test_equality_selectivity_value_in_bucket() {
    // For a value covered by the histogram the selectivity must lie in (0, 1].
    let samples: Vec<f64> = (1..=10_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 2);

    let selectivity = hist.equality_selectivity(3.0);
    assert!(selectivity > 0.0);
    assert!(selectivity <= 1.0);
}
2491
#[test]
fn test_equality_selectivity_value_outside_all_buckets() {
    // 9999 is far above the sampled values 1..=10; for such a probe the
    // selectivity is expected to be 0.001.
    let values: Vec<f64> = (1..=10).map(|i| i as f64).collect();
    let hist = EquiDepthHistogram::build(&values, 2);
    let sel = hist.equality_selectivity(9999.0);

    // Compare with a tolerance instead of `assert_eq!`: exact float equality
    // is brittle and inconsistent with the other float assertions in this
    // module.
    assert!(
        (sel - 0.001).abs() < f64::EPSILON,
        "expected selectivity 0.001 for an out-of-range value, got {sel}"
    );
}
2500
#[test]
fn test_histogram_min_max_empty() {
    // Without buckets there is neither a minimum nor a maximum.
    let hist = EquiDepthHistogram::new(Vec::new());

    let smallest = hist.min_value();
    let largest = hist.max_value();
    assert_eq!(smallest, None);
    assert_eq!(largest, None);
}
2507
#[test]
fn test_histogram_min_max_single_bucket() {
    // With one bucket, min and max are that bucket's lower and upper bounds.
    let only_bucket = HistogramBucket::new(1.0, 10.0, 5, 5);
    let hist = EquiDepthHistogram::new(vec![only_bucket]);

    let smallest = hist.min_value();
    let largest = hist.max_value();
    assert_eq!(smallest, Some(1.0));
    assert_eq!(largest, Some(10.0));
}
2514
#[test]
fn test_histogram_min_max_multi_bucket() {
    // Across several buckets the minimum is the smallest sample, while the
    // maximum only has to be at least the largest sample.
    let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0];
    let hist = EquiDepthHistogram::build(&samples, 3);

    let min = hist.min_value().unwrap();
    assert!((min - 1.0).abs() < 1e-9, "min should be 1.0, got {min}");

    let max = hist.max_value().unwrap();
    assert!(max >= 20.0, "max should be >= last value, got {max}");
}
2524
#[test]
fn test_count_and_conjuncts_single_expression() {
    use crate::query::plan::LogicalExpression;

    // An expression without any AND counts as a single conjunct.
    let lone_literal = LogicalExpression::Literal(Value::Bool(true));
    let conjuncts = count_and_conjuncts(&lone_literal);
    assert_eq!(conjuncts, 1);
}
2531
#[test]
fn test_count_and_conjuncts_flat_and() {
    use crate::query::plan::{BinaryOp, LogicalExpression};

    // `true AND false` has two conjuncts.
    let lhs = LogicalExpression::Literal(Value::Bool(true));
    let rhs = LogicalExpression::Literal(Value::Bool(false));
    let conjunction = LogicalExpression::Binary {
        left: Box::new(lhs),
        op: BinaryOp::And,
        right: Box::new(rhs),
    };

    let conjuncts = count_and_conjuncts(&conjunction);
    assert_eq!(conjuncts, 2);
}
2542
/// `(a AND b) AND (c AND d)` must count as four conjuncts: AND nodes nested on
/// both sides are counted through to their leaves.
#[test]
fn test_count_and_conjuncts_nested_and() {
    use crate::query::plan::{BinaryOp, LogicalExpression};

    // Builds `left AND right`.
    let and = |left: LogicalExpression, right: LogicalExpression| LogicalExpression::Binary {
        left: Box::new(left),
        op: BinaryOp::And,
        right: Box::new(right),
    };

    let bool_pair = and(
        LogicalExpression::Literal(Value::Bool(true)),
        LogicalExpression::Literal(Value::Bool(false)),
    );
    let int_pair = and(
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Literal(Value::Int64(2)),
    );
    let nested = and(bool_pair, int_pair);

    assert_eq!(count_and_conjuncts(&nested), 4);
}
2563
/// An empty slice contains zero distinct values.
#[test]
fn test_count_distinct_empty() {
    let distinct = count_distinct(&[]);
    assert_eq!(distinct, 0);
}
2568
/// Four pairwise-different values yield a distinct count of four.
#[test]
fn test_count_distinct_all_unique() {
    let distinct = count_distinct(&[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(distinct, 4);
}
2573
/// Repeated values are counted once: `[1, 1, 2, 2, 3]` has three distinct values.
#[test]
fn test_count_distinct_with_duplicates() {
    let distinct = count_distinct(&[1.0, 1.0, 2.0, 2.0, 3.0]);
    assert_eq!(distinct, 3);
}
2578
/// A slice holding the same value three times has exactly one distinct value.
#[test]
fn test_count_distinct_all_same() {
    let distinct = count_distinct(&[5.0, 5.0, 5.0]);
    assert_eq!(distinct, 1);
}
2583
/// A one-element slice has exactly one distinct value.
#[test]
fn test_count_distinct_single_value() {
    let distinct = count_distinct(&[42.0]);
    assert_eq!(distinct, 1);
}
2588
/// Checks `VectorScan` estimates for a top-k scan with and without a
/// similarity threshold.
///
/// With `k = 10` and no threshold the estimate must equal 10. Adding
/// `min_similarity = 0.8` must lower it to 7 while keeping it at least 1.
#[test]
fn test_estimate_vector_scan_topk_and_threshold() {
    use crate::query::plan::VectorScanOp;

    // The two scans under test differ only in their similarity threshold.
    let scan_with = |min_similarity| {
        LogicalOperator::VectorScan(VectorScanOp {
            variable: "n".to_string(),
            index_name: None,
            property: "embedding".to_string(),
            label: None,
            query_vector: LogicalExpression::Variable("q".to_string()),
            k: Some(10),
            metric: None,
            min_similarity,
            max_distance: None,
            input: None,
        })
    };

    let estimator = CardinalityEstimator::new();

    let unfiltered_scan = scan_with(None);
    let plain_card = estimator.estimate(&unfiltered_scan);
    assert!(plain_card <= 10.0);
    assert!((plain_card - 10.0).abs() < 1e-9);

    let thresholded_scan = scan_with(Some(0.8));
    let filtered = estimator.estimate(&thresholded_scan);
    assert!(filtered < plain_card);
    assert!(filtered >= 1.0);
    assert!((filtered - 7.0).abs() < 1e-9);
}
2633
/// Checks `VectorJoin` estimates over a 40-row `Article` scan with `k = 5`.
///
/// NOTE(review): despite the `text_scan` in its name, this test only builds
/// `VectorJoin` operators — confirm whether the name should be updated.
///
/// Without a threshold the estimate must equal 200; with
/// `min_similarity = 0.5` it must drop to 140.
#[test]
fn test_estimate_text_scan_topk_and_threshold() {
    use crate::query::plan::VectorJoinOp;

    // The two joins under test differ only in their input instance and
    // similarity threshold.
    let join_over = |input: LogicalOperator, min_similarity| {
        LogicalOperator::VectorJoin(VectorJoinOp {
            input: Box::new(input),
            left_vector_variable: None,
            left_property: None,
            query_vector: LogicalExpression::Variable("q".to_string()),
            right_variable: "m".to_string(),
            right_property: "emb".to_string(),
            right_label: None,
            index_name: None,
            k: 5,
            metric: None,
            min_similarity,
            max_distance: None,
            score_variable: None,
        })
    };

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Article", TableStats::new(40));

    let article_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Article".to_string()),
        input: None,
    });

    let unfiltered_join = join_over(article_scan.clone(), None);
    let plain_card = estimator.estimate(&unfiltered_join);
    assert!((plain_card - 200.0).abs() < 1e-9);

    let thresholded_join = join_over(article_scan, Some(0.5));
    let filtered = estimator.estimate(&thresholded_join);
    assert!(filtered < plain_card);
    assert!((filtered - 140.0).abs() < 1e-9);
}
2688
/// A three-input `MultiWayJoin` over tables of 1000, 50 and 200 rows must be
/// estimated at `50^1.5`, well below the 1000 * 50 * 200 cross product.
///
/// NOTE(review): the expected value looks like an AGM-style bound derived
/// from the smallest input — confirm against the estimator implementation.
#[test]
fn test_estimate_multi_way_join_agm_bound() {
    // (scan variable, label, row count) for each joined relation.
    let relations = [
        ("p", "Person", 1000),
        ("w", "Works", 50),
        ("c", "Company", 200),
    ];

    let mut estimator = CardinalityEstimator::new();
    for &(_, label, rows) in &relations {
        estimator.add_table_stats(label, TableStats::new(rows));
    }

    let scans = relations
        .iter()
        .map(|&(variable, label, _)| {
            LogicalOperator::NodeScan(NodeScanOp {
                variable: variable.to_string(),
                label: Some(label.to_string()),
                input: None,
            })
        })
        .collect();

    let mwj = LogicalOperator::MultiWayJoin(MultiWayJoinOp {
        inputs: scans,
        conditions: vec![],
        shared_variables: vec!["p".to_string()],
    });

    // `card` / `expected` are captured by name in the assertion message.
    let card = estimator.estimate(&mwj);
    let expected = 50.0_f64.powf(1.5);
    assert!(
        (card - expected).abs() < 0.01,
        "got {card}, expected {expected}"
    );
    assert!(card < 1000.0 * 50.0 * 200.0);
}
2731
/// A `MultiWayJoin` with no inputs at all must be estimated at (effectively)
/// zero rows.
#[test]
fn test_estimate_multi_way_join_empty_inputs() {
    let empty_join = LogicalOperator::MultiWayJoin(MultiWayJoinOp {
        inputs: Vec::new(),
        conditions: Vec::new(),
        shared_variables: Vec::new(),
    });

    let estimator = CardinalityEstimator::new();
    let card = estimator.estimate(&empty_join);

    assert!(card.abs() < f64::EPSILON);
}
2743
/// Filters a 1000-row `Person` scan on `n.age >= 25 AND n.age <= 65`, where
/// `age` has column stats with a recorded range of 18.0..80.0.
///
/// The estimate must be strictly below the unfiltered 1000 rows and strictly
/// above 10.
#[test]
fn test_range_selectivity_with_histogram_fallback() {
    // Builds `n.age <op> <bound>`.
    let age_compared = |op: BinaryOp, bound| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(Value::Int64(bound))),
    };

    let mut estimator = CardinalityEstimator::new();
    let age_stats = ColumnStats::new(62).with_range(18.0, 80.0);
    let person_stats = TableStats::new(1000).with_column("age", age_stats);
    estimator.add_table_stats("Person", person_stats);

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });

    let predicate = LogicalExpression::Binary {
        left: Box::new(age_compared(BinaryOp::Ge, 25)),
        op: BinaryOp::And,
        right: Box::new(age_compared(BinaryOp::Le, 65)),
    };

    let filter = LogicalOperator::Filter(FilterOp {
        predicate,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let card = estimator.estimate(&filter);
    assert!(card < 1000.0);
    assert!(card > 10.0);
}
2794}