1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19 LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20 VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// A single histogram bucket covering the half-open value range
/// `[lower_bound, upper_bound)`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper edge of the bucket.
    pub upper_bound: f64,
    /// Number of rows that fall into this bucket.
    pub frequency: u64,
    /// Number of distinct values that fall into this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Assembles a bucket from its bounds and counts. No validation is done.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Distance between the upper and the lower bound.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` when `value` lies in `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Fraction (0.0 to 1.0) of this bucket covered by the query range
    /// `lower..upper`.
    ///
    /// A missing bound means the range is open on that side. Buckets whose
    /// width is zero or negative count as fully covered when the range spans
    /// both of their edges and as not covered otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the query range to the bucket's own extent.
        let clipped_lo = lower.map_or(self.lower_bound, |lo| lo.max(self.lower_bound));
        let clipped_hi = upper.map_or(self.upper_bound, |hi| hi.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Degenerate bucket: all or nothing.
            let spans_bucket = clipped_lo <= self.lower_bound && clipped_hi >= self.upper_bound;
            return if spans_bucket { 1.0 } else { 0.0 };
        }

        let covered = (clipped_hi - clipped_lo).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106 buckets: Vec<HistogramBucket>,
108 total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113 #[must_use]
115 pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116 let total_rows = buckets.iter().map(|b| b.frequency).sum();
117 Self {
118 buckets,
119 total_rows,
120 }
121 }
122
123 #[must_use]
132 pub fn build(values: &[f64], num_buckets: usize) -> Self {
133 if values.is_empty() || num_buckets == 0 {
134 return Self {
135 buckets: Vec::new(),
136 total_rows: 0,
137 };
138 }
139
140 let num_buckets = num_buckets.min(values.len());
141 let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142 let mut buckets = Vec::with_capacity(num_buckets);
143
144 let mut start_idx = 0;
145 while start_idx < values.len() {
146 let end_idx = (start_idx + rows_per_bucket).min(values.len());
147 let lower_bound = values[start_idx];
148 let upper_bound = if end_idx < values.len() {
149 values[end_idx]
150 } else {
151 values[end_idx - 1] + 1.0
153 };
154
155 let bucket_values = &values[start_idx..end_idx];
157 let distinct_count = count_distinct(bucket_values);
158
159 buckets.push(HistogramBucket::new(
160 lower_bound,
161 upper_bound,
162 (end_idx - start_idx) as u64,
163 distinct_count,
164 ));
165
166 start_idx = end_idx;
167 }
168
169 Self::new(buckets)
170 }
171
172 #[must_use]
174 pub fn num_buckets(&self) -> usize {
175 self.buckets.len()
176 }
177
178 #[must_use]
180 pub fn total_rows(&self) -> u64 {
181 self.total_rows
182 }
183
184 #[must_use]
186 pub fn buckets(&self) -> &[HistogramBucket] {
187 &self.buckets
188 }
189
190 #[must_use]
199 pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200 if self.buckets.is_empty() || self.total_rows == 0 {
201 return 0.33; }
203
204 let mut matching_rows = 0.0;
205
206 for bucket in &self.buckets {
207 let bucket_lower = bucket.lower_bound;
209 let bucket_upper = bucket.upper_bound;
210
211 if let Some(l) = lower
213 && bucket_upper <= l
214 {
215 continue;
216 }
217 if let Some(u) = upper
218 && bucket_lower >= u
219 {
220 continue;
221 }
222
223 let overlap = bucket.overlap_fraction(lower, upper);
225 matching_rows += overlap * bucket.frequency as f64;
226 }
227
228 (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
229 }
230
231 #[must_use]
235 pub fn equality_selectivity(&self, value: f64) -> f64 {
236 if self.buckets.is_empty() || self.total_rows == 0 {
237 return 0.01; }
239
240 for bucket in &self.buckets {
242 if bucket.contains(value) {
243 if bucket.distinct_count > 0 {
245 return (bucket.frequency as f64
246 / bucket.distinct_count as f64
247 / self.total_rows as f64)
248 .min(1.0);
249 }
250 }
251 }
252
253 0.001
255 }
256
257 #[must_use]
259 pub fn min_value(&self) -> Option<f64> {
260 self.buckets.first().map(|b| b.lower_bound)
261 }
262
263 #[must_use]
265 pub fn max_value(&self) -> Option<f64> {
266 self.buckets.last().map(|b| b.upper_bound)
267 }
268}
269
/// Counts distinct values in an ascending-sorted slice.
///
/// A value counts as new when it differs from the most recently counted
/// value by more than `f64::EPSILON`. An empty slice yields `0`.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let Some((&first, rest)) = sorted_values.split_first() else {
        return 0;
    };

    // The accumulator carries (distinct values so far, last counted value).
    let (distinct, _) = rest.iter().fold((1u64, first), |(seen, anchor), &candidate| {
        if (candidate - anchor).abs() > f64::EPSILON {
            (seen + 1, candidate)
        } else {
            (seen, anchor)
        }
    });

    distinct
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
396pub struct CardinalityEstimator {
398 table_stats: HashMap<String, TableStats>,
400 default_row_count: u64,
402 default_selectivity: f64,
404 avg_fanout: f64,
406}
407
408impl CardinalityEstimator {
409 #[must_use]
411 pub fn new() -> Self {
412 Self {
413 table_stats: HashMap::new(),
414 default_row_count: 1000,
415 default_selectivity: 0.1,
416 avg_fanout: 10.0,
417 }
418 }
419
420 #[must_use]
426 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
427 let mut estimator = Self::new();
428
429 if stats.total_nodes > 0 {
431 estimator.default_row_count = stats.total_nodes;
432 }
433
434 for (label, label_stats) in &stats.labels {
436 let mut table_stats = TableStats::new(label_stats.node_count);
437
438 for (prop, col_stats) in &label_stats.properties {
440 let optimizer_col =
441 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
442 table_stats = table_stats.with_column(prop, optimizer_col);
443 }
444
445 estimator.add_table_stats(label, table_stats);
446 }
447
448 if !stats.edge_types.is_empty() {
450 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
451 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
452 } else if stats.total_nodes > 0 {
453 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
454 }
455
456 if estimator.avg_fanout < 1.0 {
458 estimator.avg_fanout = 1.0;
459 }
460
461 estimator
462 }
463
464 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
466 self.table_stats.insert(name.to_string(), stats);
467 }
468
469 pub fn set_avg_fanout(&mut self, fanout: f64) {
471 self.avg_fanout = fanout;
472 }
473
474 #[must_use]
476 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
477 match op {
478 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
479 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
480 LogicalOperator::Project(project) => self.estimate_project(project),
481 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
482 LogicalOperator::Join(join) => self.estimate_join(join),
483 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
484 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
485 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
486 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
487 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
488 LogicalOperator::Return(ret) => self.estimate(&ret.input),
489 LogicalOperator::Empty => 0.0,
490 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
491 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
492 _ => self.default_row_count as f64,
493 }
494 }
495
496 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
498 if let Some(label) = &scan.label
499 && let Some(stats) = self.table_stats.get(label)
500 {
501 return stats.row_count as f64;
502 }
503 self.default_row_count as f64
505 }
506
507 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
509 let input_cardinality = self.estimate(&filter.input);
510 let selectivity = self.estimate_selectivity(&filter.predicate);
511 (input_cardinality * selectivity).max(1.0)
512 }
513
514 fn estimate_project(&self, project: &ProjectOp) -> f64 {
516 self.estimate(&project.input)
517 }
518
519 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
521 let input_cardinality = self.estimate(&expand.input);
522
523 let fanout = if expand.edge_type.is_some() {
525 self.avg_fanout * 0.5
527 } else {
528 self.avg_fanout
529 };
530
531 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
533 let min = expand.min_hops as f64;
534 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
535 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
537 } else {
538 fanout
539 };
540
541 (input_cardinality * path_multiplier).max(1.0)
542 }
543
544 fn estimate_join(&self, join: &JoinOp) -> f64 {
546 let left_card = self.estimate(&join.left);
547 let right_card = self.estimate(&join.right);
548
549 match join.join_type {
550 JoinType::Cross => left_card * right_card,
551 JoinType::Inner => {
552 let selectivity = if join.conditions.is_empty() {
554 1.0 } else {
556 0.1_f64.powi(join.conditions.len() as i32)
558 };
559 (left_card * right_card * selectivity).max(1.0)
560 }
561 JoinType::Left => {
562 let inner_card = self.estimate_join(&JoinOp {
564 left: join.left.clone(),
565 right: join.right.clone(),
566 join_type: JoinType::Inner,
567 conditions: join.conditions.clone(),
568 });
569 inner_card.max(left_card)
570 }
571 JoinType::Right => {
572 let inner_card = self.estimate_join(&JoinOp {
574 left: join.left.clone(),
575 right: join.right.clone(),
576 join_type: JoinType::Inner,
577 conditions: join.conditions.clone(),
578 });
579 inner_card.max(right_card)
580 }
581 JoinType::Full => {
582 let inner_card = self.estimate_join(&JoinOp {
584 left: join.left.clone(),
585 right: join.right.clone(),
586 join_type: JoinType::Inner,
587 conditions: join.conditions.clone(),
588 });
589 inner_card.max(left_card.max(right_card))
590 }
591 JoinType::Semi => {
592 (left_card * self.default_selectivity).max(1.0)
594 }
595 JoinType::Anti => {
596 (left_card * (1.0 - self.default_selectivity)).max(1.0)
598 }
599 }
600 }
601
602 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
604 let input_cardinality = self.estimate(&agg.input);
605
606 if agg.group_by.is_empty() {
607 1.0
609 } else {
610 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
613 (input_cardinality / group_reduction).max(1.0)
614 }
615 }
616
617 fn estimate_sort(&self, sort: &SortOp) -> f64 {
619 self.estimate(&sort.input)
620 }
621
622 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
624 let input_cardinality = self.estimate(&distinct.input);
625 (input_cardinality * 0.5).max(1.0)
627 }
628
629 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
631 let input_cardinality = self.estimate(&limit.input);
632 (limit.count as f64).min(input_cardinality)
633 }
634
635 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
637 let input_cardinality = self.estimate(&skip.input);
638 (input_cardinality - skip.count as f64).max(0.0)
639 }
640
641 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
646 let base_k = scan.k as f64;
647
648 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
650 0.7
652 } else {
653 1.0
654 };
655
656 (base_k * selectivity).max(1.0)
657 }
658
659 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
663 let input_cardinality = self.estimate(&join.input);
664 let k = join.k as f64;
665
666 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
668 0.7
669 } else {
670 1.0
671 };
672
673 (input_cardinality * k * selectivity).max(1.0)
674 }
675
676 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
678 match expr {
679 LogicalExpression::Binary { left, op, right } => {
680 self.estimate_binary_selectivity(left, *op, right)
681 }
682 LogicalExpression::Unary { op, operand } => {
683 self.estimate_unary_selectivity(*op, operand)
684 }
685 LogicalExpression::Literal(value) => {
686 if let grafeo_common::types::Value::Bool(b) = value {
688 if *b { 1.0 } else { 0.0 }
689 } else {
690 self.default_selectivity
691 }
692 }
693 _ => self.default_selectivity,
694 }
695 }
696
697 fn estimate_binary_selectivity(
699 &self,
700 left: &LogicalExpression,
701 op: BinaryOp,
702 right: &LogicalExpression,
703 ) -> f64 {
704 match op {
705 BinaryOp::Eq => {
707 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
708 return selectivity;
709 }
710 0.01
711 }
712 BinaryOp::Ne => 0.99,
714 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
716 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
717 return selectivity;
718 }
719 0.33
720 }
721 BinaryOp::And => {
723 let left_sel = self.estimate_selectivity(left);
724 let right_sel = self.estimate_selectivity(right);
725 left_sel * right_sel
727 }
728 BinaryOp::Or => {
729 let left_sel = self.estimate_selectivity(left);
730 let right_sel = self.estimate_selectivity(right);
731 (left_sel + right_sel - left_sel * right_sel).min(1.0)
734 }
735 BinaryOp::StartsWith => 0.1,
737 BinaryOp::EndsWith => 0.1,
738 BinaryOp::Contains => 0.1,
739 BinaryOp::Like => 0.1,
740 BinaryOp::In => 0.1,
742 _ => self.default_selectivity,
744 }
745 }
746
747 fn try_equality_selectivity(
749 &self,
750 left: &LogicalExpression,
751 right: &LogicalExpression,
752 ) -> Option<f64> {
753 let (label, column, value) = self.extract_column_and_value(left, right)?;
755
756 let stats = self.get_column_stats(&label, &column)?;
758
759 if let Some(ref histogram) = stats.histogram {
761 return Some(histogram.equality_selectivity(value));
762 }
763
764 if stats.distinct_count > 0 {
766 return Some(1.0 / stats.distinct_count as f64);
767 }
768
769 None
770 }
771
772 fn try_range_selectivity(
774 &self,
775 left: &LogicalExpression,
776 op: BinaryOp,
777 right: &LogicalExpression,
778 ) -> Option<f64> {
779 let (label, column, value) = self.extract_column_and_value(left, right)?;
781
782 let stats = self.get_column_stats(&label, &column)?;
784
785 let (lower, upper) = match op {
787 BinaryOp::Lt => (None, Some(value)),
788 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
789 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
790 BinaryOp::Ge => (Some(value), None),
791 _ => return None,
792 };
793
794 if let Some(ref histogram) = stats.histogram {
796 return Some(histogram.range_selectivity(lower, upper));
797 }
798
799 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
801 let range = max - min;
802 if range <= 0.0 {
803 return Some(1.0);
804 }
805
806 let effective_lower = lower.unwrap_or(min).max(min);
807 let effective_upper = upper.unwrap_or(max).min(max);
808 let overlap = (effective_upper - effective_lower).max(0.0);
809 return Some((overlap / range).min(1.0).max(0.0));
810 }
811
812 None
813 }
814
815 fn extract_column_and_value(
820 &self,
821 left: &LogicalExpression,
822 right: &LogicalExpression,
823 ) -> Option<(String, String, f64)> {
824 if let Some(result) = self.try_extract_property_literal(left, right) {
826 return Some(result);
827 }
828
829 self.try_extract_property_literal(right, left)
831 }
832
833 fn try_extract_property_literal(
835 &self,
836 property_expr: &LogicalExpression,
837 literal_expr: &LogicalExpression,
838 ) -> Option<(String, String, f64)> {
839 let (variable, property) = match property_expr {
841 LogicalExpression::Property { variable, property } => {
842 (variable.clone(), property.clone())
843 }
844 _ => return None,
845 };
846
847 let value = match literal_expr {
849 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
850 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
851 _ => return None,
852 };
853
854 for label in self.table_stats.keys() {
858 if let Some(stats) = self.table_stats.get(label)
859 && stats.columns.contains_key(&property)
860 {
861 return Some((label.clone(), property, value));
862 }
863 }
864
865 Some((variable, property, value))
867 }
868
869 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
871 match op {
872 UnaryOp::Not => 1.0 - self.default_selectivity,
873 UnaryOp::IsNull => 0.05, UnaryOp::IsNotNull => 0.95,
875 UnaryOp::Neg => 1.0, }
877 }
878
879 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
881 self.table_stats.get(label)?.columns.get(column)
882 }
883
884 #[allow(dead_code)]
886 fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
887 if let Some(stats) = self.get_column_stats(label, column)
888 && stats.distinct_count > 0
889 {
890 return 1.0 / stats.distinct_count as f64;
891 }
892 0.01 }
894
895 #[allow(dead_code)]
897 fn estimate_range_with_stats(
898 &self,
899 label: &str,
900 column: &str,
901 lower: Option<f64>,
902 upper: Option<f64>,
903 ) -> f64 {
904 if let Some(stats) = self.get_column_stats(label, column)
905 && let (Some(min), Some(max)) = (stats.min_value, stats.max_value)
906 {
907 let range = max - min;
908 if range <= 0.0 {
909 return 1.0;
910 }
911
912 let effective_lower = lower.unwrap_or(min).max(min);
913 let effective_upper = upper.unwrap_or(max).min(max);
914
915 let overlap = (effective_upper - effective_lower).max(0.0);
916 return (overlap / range).min(1.0).max(0.0);
917 }
918 0.33 }
920}
921
922impl Default for CardinalityEstimator {
923 fn default() -> Self {
924 Self::new()
925 }
926}
927
928#[cfg(test)]
929mod tests {
930 use super::*;
931 use crate::query::plan::{
932 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
933 Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
934 };
935 use grafeo_common::types::Value;
936
937 #[test]
938 fn test_node_scan_with_stats() {
939 let mut estimator = CardinalityEstimator::new();
940 estimator.add_table_stats("Person", TableStats::new(5000));
941
942 let scan = LogicalOperator::NodeScan(NodeScanOp {
943 variable: "n".to_string(),
944 label: Some("Person".to_string()),
945 input: None,
946 });
947
948 let cardinality = estimator.estimate(&scan);
949 assert!((cardinality - 5000.0).abs() < 0.001);
950 }
951
952 #[test]
953 fn test_filter_reduces_cardinality() {
954 let mut estimator = CardinalityEstimator::new();
955 estimator.add_table_stats("Person", TableStats::new(1000));
956
957 let filter = LogicalOperator::Filter(FilterOp {
958 predicate: LogicalExpression::Binary {
959 left: Box::new(LogicalExpression::Property {
960 variable: "n".to_string(),
961 property: "age".to_string(),
962 }),
963 op: BinaryOp::Eq,
964 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
965 },
966 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
967 variable: "n".to_string(),
968 label: Some("Person".to_string()),
969 input: None,
970 })),
971 });
972
973 let cardinality = estimator.estimate(&filter);
974 assert!(cardinality < 1000.0);
976 assert!(cardinality >= 1.0);
977 }
978
979 #[test]
980 fn test_join_cardinality() {
981 let mut estimator = CardinalityEstimator::new();
982 estimator.add_table_stats("Person", TableStats::new(1000));
983 estimator.add_table_stats("Company", TableStats::new(100));
984
985 let join = LogicalOperator::Join(JoinOp {
986 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
987 variable: "p".to_string(),
988 label: Some("Person".to_string()),
989 input: None,
990 })),
991 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
992 variable: "c".to_string(),
993 label: Some("Company".to_string()),
994 input: None,
995 })),
996 join_type: JoinType::Inner,
997 conditions: vec![JoinCondition {
998 left: LogicalExpression::Property {
999 variable: "p".to_string(),
1000 property: "company_id".to_string(),
1001 },
1002 right: LogicalExpression::Property {
1003 variable: "c".to_string(),
1004 property: "id".to_string(),
1005 },
1006 }],
1007 });
1008
1009 let cardinality = estimator.estimate(&join);
1010 assert!(cardinality < 1000.0 * 100.0);
1012 }
1013
1014 #[test]
1015 fn test_limit_caps_cardinality() {
1016 let mut estimator = CardinalityEstimator::new();
1017 estimator.add_table_stats("Person", TableStats::new(1000));
1018
1019 let limit = LogicalOperator::Limit(LimitOp {
1020 count: 10,
1021 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1022 variable: "n".to_string(),
1023 label: Some("Person".to_string()),
1024 input: None,
1025 })),
1026 });
1027
1028 let cardinality = estimator.estimate(&limit);
1029 assert!((cardinality - 10.0).abs() < 0.001);
1030 }
1031
1032 #[test]
1033 fn test_aggregate_reduces_cardinality() {
1034 let mut estimator = CardinalityEstimator::new();
1035 estimator.add_table_stats("Person", TableStats::new(1000));
1036
1037 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1039 group_by: vec![],
1040 aggregates: vec![],
1041 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1042 variable: "n".to_string(),
1043 label: Some("Person".to_string()),
1044 input: None,
1045 })),
1046 having: None,
1047 });
1048
1049 let cardinality = estimator.estimate(&global_agg);
1050 assert!((cardinality - 1.0).abs() < 0.001);
1051
1052 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1054 group_by: vec![LogicalExpression::Property {
1055 variable: "n".to_string(),
1056 property: "city".to_string(),
1057 }],
1058 aggregates: vec![],
1059 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1060 variable: "n".to_string(),
1061 label: Some("Person".to_string()),
1062 input: None,
1063 })),
1064 having: None,
1065 });
1066
1067 let cardinality = estimator.estimate(&group_agg);
1068 assert!(cardinality < 1000.0);
1070 }
1071
1072 #[test]
1073 fn test_node_scan_without_stats() {
1074 let estimator = CardinalityEstimator::new();
1075
1076 let scan = LogicalOperator::NodeScan(NodeScanOp {
1077 variable: "n".to_string(),
1078 label: Some("Unknown".to_string()),
1079 input: None,
1080 });
1081
1082 let cardinality = estimator.estimate(&scan);
1083 assert!((cardinality - 1000.0).abs() < 0.001);
1085 }
1086
1087 #[test]
1088 fn test_node_scan_no_label() {
1089 let estimator = CardinalityEstimator::new();
1090
1091 let scan = LogicalOperator::NodeScan(NodeScanOp {
1092 variable: "n".to_string(),
1093 label: None,
1094 input: None,
1095 });
1096
1097 let cardinality = estimator.estimate(&scan);
1098 assert!((cardinality - 1000.0).abs() < 0.001);
1100 }
1101
1102 #[test]
1103 fn test_filter_inequality_selectivity() {
1104 let mut estimator = CardinalityEstimator::new();
1105 estimator.add_table_stats("Person", TableStats::new(1000));
1106
1107 let filter = LogicalOperator::Filter(FilterOp {
1108 predicate: LogicalExpression::Binary {
1109 left: Box::new(LogicalExpression::Property {
1110 variable: "n".to_string(),
1111 property: "age".to_string(),
1112 }),
1113 op: BinaryOp::Ne,
1114 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1115 },
1116 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1117 variable: "n".to_string(),
1118 label: Some("Person".to_string()),
1119 input: None,
1120 })),
1121 });
1122
1123 let cardinality = estimator.estimate(&filter);
1124 assert!(cardinality > 900.0);
1126 }
1127
1128 #[test]
1129 fn test_filter_range_selectivity() {
1130 let mut estimator = CardinalityEstimator::new();
1131 estimator.add_table_stats("Person", TableStats::new(1000));
1132
1133 let filter = LogicalOperator::Filter(FilterOp {
1134 predicate: LogicalExpression::Binary {
1135 left: Box::new(LogicalExpression::Property {
1136 variable: "n".to_string(),
1137 property: "age".to_string(),
1138 }),
1139 op: BinaryOp::Gt,
1140 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1141 },
1142 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1143 variable: "n".to_string(),
1144 label: Some("Person".to_string()),
1145 input: None,
1146 })),
1147 });
1148
1149 let cardinality = estimator.estimate(&filter);
1150 assert!(cardinality < 500.0);
1152 assert!(cardinality > 100.0);
1153 }
1154
1155 #[test]
1156 fn test_filter_and_selectivity() {
1157 let mut estimator = CardinalityEstimator::new();
1158 estimator.add_table_stats("Person", TableStats::new(1000));
1159
1160 let filter = LogicalOperator::Filter(FilterOp {
1163 predicate: LogicalExpression::Binary {
1164 left: Box::new(LogicalExpression::Binary {
1165 left: Box::new(LogicalExpression::Property {
1166 variable: "n".to_string(),
1167 property: "city".to_string(),
1168 }),
1169 op: BinaryOp::Eq,
1170 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1171 }),
1172 op: BinaryOp::And,
1173 right: Box::new(LogicalExpression::Binary {
1174 left: Box::new(LogicalExpression::Property {
1175 variable: "n".to_string(),
1176 property: "age".to_string(),
1177 }),
1178 op: BinaryOp::Eq,
1179 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1180 }),
1181 },
1182 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1183 variable: "n".to_string(),
1184 label: Some("Person".to_string()),
1185 input: None,
1186 })),
1187 });
1188
1189 let cardinality = estimator.estimate(&filter);
1190 assert!(cardinality < 100.0);
1193 assert!(cardinality >= 1.0);
1194 }
1195
1196 #[test]
1197 fn test_filter_or_selectivity() {
1198 let mut estimator = CardinalityEstimator::new();
1199 estimator.add_table_stats("Person", TableStats::new(1000));
1200
1201 let filter = LogicalOperator::Filter(FilterOp {
1205 predicate: LogicalExpression::Binary {
1206 left: Box::new(LogicalExpression::Binary {
1207 left: Box::new(LogicalExpression::Property {
1208 variable: "n".to_string(),
1209 property: "city".to_string(),
1210 }),
1211 op: BinaryOp::Eq,
1212 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1213 }),
1214 op: BinaryOp::Or,
1215 right: Box::new(LogicalExpression::Binary {
1216 left: Box::new(LogicalExpression::Property {
1217 variable: "n".to_string(),
1218 property: "city".to_string(),
1219 }),
1220 op: BinaryOp::Eq,
1221 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1222 }),
1223 },
1224 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1225 variable: "n".to_string(),
1226 label: Some("Person".to_string()),
1227 input: None,
1228 })),
1229 });
1230
1231 let cardinality = estimator.estimate(&filter);
1232 assert!(cardinality < 100.0);
1234 assert!(cardinality >= 1.0);
1235 }
1236
1237 #[test]
1238 fn test_filter_literal_true() {
1239 let mut estimator = CardinalityEstimator::new();
1240 estimator.add_table_stats("Person", TableStats::new(1000));
1241
1242 let filter = LogicalOperator::Filter(FilterOp {
1243 predicate: LogicalExpression::Literal(Value::Bool(true)),
1244 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1245 variable: "n".to_string(),
1246 label: Some("Person".to_string()),
1247 input: None,
1248 })),
1249 });
1250
1251 let cardinality = estimator.estimate(&filter);
1252 assert!((cardinality - 1000.0).abs() < 0.001);
1254 }
1255
1256 #[test]
1257 fn test_filter_literal_false() {
1258 let mut estimator = CardinalityEstimator::new();
1259 estimator.add_table_stats("Person", TableStats::new(1000));
1260
1261 let filter = LogicalOperator::Filter(FilterOp {
1262 predicate: LogicalExpression::Literal(Value::Bool(false)),
1263 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1264 variable: "n".to_string(),
1265 label: Some("Person".to_string()),
1266 input: None,
1267 })),
1268 });
1269
1270 let cardinality = estimator.estimate(&filter);
1271 assert!((cardinality - 1.0).abs() < 0.001);
1273 }
1274
1275 #[test]
1276 fn test_unary_not_selectivity() {
1277 let mut estimator = CardinalityEstimator::new();
1278 estimator.add_table_stats("Person", TableStats::new(1000));
1279
1280 let filter = LogicalOperator::Filter(FilterOp {
1281 predicate: LogicalExpression::Unary {
1282 op: UnaryOp::Not,
1283 operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1284 },
1285 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286 variable: "n".to_string(),
1287 label: Some("Person".to_string()),
1288 input: None,
1289 })),
1290 });
1291
1292 let cardinality = estimator.estimate(&filter);
1293 assert!(cardinality < 1000.0);
1295 }
1296
1297 #[test]
1298 fn test_unary_is_null_selectivity() {
1299 let mut estimator = CardinalityEstimator::new();
1300 estimator.add_table_stats("Person", TableStats::new(1000));
1301
1302 let filter = LogicalOperator::Filter(FilterOp {
1303 predicate: LogicalExpression::Unary {
1304 op: UnaryOp::IsNull,
1305 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1306 },
1307 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1308 variable: "n".to_string(),
1309 label: Some("Person".to_string()),
1310 input: None,
1311 })),
1312 });
1313
1314 let cardinality = estimator.estimate(&filter);
1315 assert!(cardinality < 100.0);
1317 }
1318
1319 #[test]
1320 fn test_expand_cardinality() {
1321 let mut estimator = CardinalityEstimator::new();
1322 estimator.add_table_stats("Person", TableStats::new(100));
1323
1324 let expand = LogicalOperator::Expand(ExpandOp {
1325 from_variable: "a".to_string(),
1326 to_variable: "b".to_string(),
1327 edge_variable: None,
1328 direction: ExpandDirection::Outgoing,
1329 edge_type: None,
1330 min_hops: 1,
1331 max_hops: Some(1),
1332 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1333 variable: "a".to_string(),
1334 label: Some("Person".to_string()),
1335 input: None,
1336 })),
1337 path_alias: None,
1338 });
1339
1340 let cardinality = estimator.estimate(&expand);
1341 assert!(cardinality > 100.0);
1343 }
1344
/// A single-hop outgoing expand restricted to the `KNOWS` edge type, over a
/// 100-row `Person` scan, is still expected to be estimated above 100 rows.
#[test]
fn test_expand_with_edge_type_filter() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        // The only difference from the untyped expand test: a typed edge.
        edge_type: Some(String::from("KNOWS")),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
    });

    let estimated = est.estimate(&plan);
    assert!(estimated > 100.0);
}
1370
/// A variable-length expand (1 to 3 hops) over a 100-row `Person` scan is
/// expected to be estimated above 500 rows.
#[test]
fn test_expand_variable_length() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: None,
        // Hop range 1..=3 makes this a variable-length traversal.
        min_hops: 1,
        max_hops: Some(3),
        input: Box::new(person_scan),
        path_alias: None,
    });

    let estimated = est.estimate(&plan);
    assert!(estimated > 500.0);
}
1396
/// A cross join with no conditions between a 100-row `Person` scan and a
/// 50-row `Company` scan is expected to be estimated at 100 * 50 = 5000 rows.
#[test]
fn test_join_cross_product() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.add_table_stats("Company", TableStats::new(50));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let expected = 5000.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1422
/// A left join of a 1000-row `Person` scan against a 10-row `Company` scan
/// is expected to be estimated at no fewer than the 1000 left-side rows.
#[test]
fn test_join_left_outer() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(10));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });
    // Single join condition relating variable `p` to variable `c`.
    let on_p_c = JoinCondition {
        left: LogicalExpression::Variable(String::from("p")),
        right: LogicalExpression::Variable(String::from("c")),
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Left,
        conditions: vec![on_p_c],
    });

    let estimated = est.estimate(&plan);
    assert!(estimated >= 1000.0);
}
1451
/// A semi join of a 1000-row `Person` scan against a 100-row `Company` scan
/// is expected to be estimated at no more than the 1000 left-side rows.
#[test]
fn test_join_semi() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    let estimated = est.estimate(&plan);
    assert!(estimated <= 1000.0);
}
1477
/// An anti join of a 1000-row `Person` scan against a 100-row `Company` scan
/// is expected to be estimated within `[1, 1000]` rows.
#[test]
fn test_join_anti() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let estimated = est.estimate(&plan);
    // Upper bound first, then lower bound, as two separate checks.
    assert!(estimated <= 1000.0);
    assert!(estimated >= 1.0);
}
1504
/// Projecting a single variable from a 1000-row `Person` scan is expected
/// to leave the estimate at 1000 rows.
#[test]
fn test_project_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let keep_n = Projection {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Project(ProjectOp {
        projections: vec![keep_n],
        input: Box::new(person_scan),
    });

    let expected = 1000.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1525
/// Sorting a 1000-row `Person` scan by one ascending key is expected to
/// leave the estimate at 1000 rows.
#[test]
fn test_sort_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let by_n_asc = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
    };
    let plan = LogicalOperator::Sort(SortOp {
        keys: vec![by_n_asc],
        input: Box::new(person_scan),
    });

    let expected = 1000.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1546
/// `Distinct` with no explicit columns over a 1000-row `Person` scan is
/// expected to be estimated at 500 rows.
#[test]
fn test_distinct_reduces_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    let expected = 500.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1565
/// Skipping 100 rows of a 1000-row `Person` scan is expected to be
/// estimated at 900 rows.
#[test]
fn test_skip_reduces_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Skip(SkipOp {
        count: 100,
        input: Box::new(person_scan),
    });

    let expected = 900.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1583
/// A non-distinct `Return` of one variable over a 1000-row `Person` scan is
/// expected to leave the estimate at 1000 rows.
#[test]
fn test_return_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let return_n = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![return_n],
        distinct: false,
        input: Box::new(person_scan),
    });

    let expected = 1000.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1605
/// The `Empty` operator is expected to be estimated at (approximately)
/// zero rows, even with no statistics registered.
#[test]
fn test_empty_cardinality() {
    let est = CardinalityEstimator::new();
    let plan = LogicalOperator::Empty;

    let estimated = est.estimate(&plan);
    assert!(estimated.abs() < 0.001);
}
1612
/// `TableStats::with_column` is expected to store the supplied column stats
/// under the given name, with the distinct count, null count and range that
/// the `ColumnStats` builder calls set.
#[test]
fn test_table_stats_with_column() {
    let age_stats = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let table = TableStats::new(1000).with_column("age", age_stats);

    assert_eq!(table.row_count, 1000);

    let age = table.columns.get("age").unwrap();
    assert_eq!(age.distinct_count, 50);
    assert_eq!(age.null_count, 10);

    let lowest = age.min_value.unwrap();
    assert!((lowest - 0.0).abs() < 0.001);
    let highest = age.max_value.unwrap();
    assert!((highest - 100.0).abs() < 0.001);
}
1627
/// A `Default`-constructed estimator with no table statistics is expected
/// to estimate an unlabeled node scan at 1000 rows.
#[test]
fn test_estimator_default() {
    let est = CardinalityEstimator::default();

    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });

    let expected = 1000.0;
    let estimated = est.estimate(&unlabeled_scan);
    assert!((estimated - expected).abs() < 0.001);
}
1639
/// After `set_avg_fanout(5.0)`, a single-hop expand over a 100-row `Person`
/// scan is expected to be estimated at 500 rows.
#[test]
fn test_set_avg_fanout() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.set_avg_fanout(5.0);

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: None,
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
    });

    let expected = 500.0;
    let estimated = est.estimate(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1666
/// Aggregating a 10000-row `Person` scan by one key (`city`) and by two keys
/// (`city`, `country`) is expected, in both cases, to yield an estimate
/// below the 10000 input rows and of at least 1 row.
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(10000));

    // Both aggregates read from an identical `Person` scan bound to `n`.
    let person_scan = || {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: String::from("n"),
            label: Some(String::from("Person")),
            input: None,
        }))
    };
    // Grouping keys are all properties of `n`.
    let property_of_n = |name: &str| LogicalExpression::Property {
        variable: String::from("n"),
        property: String::from(name),
    };

    let by_city = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![property_of_n("city")],
        aggregates: Vec::new(),
        input: person_scan(),
        having: None,
    });
    let by_city_and_country = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![property_of_n("city"), property_of_n("country")],
        aggregates: Vec::new(),
        input: person_scan(),
        having: None,
    });

    let one_key = est.estimate(&by_city);
    let two_keys = est.estimate(&by_city_and_country);

    assert!(one_key < 10000.0);
    assert!(two_keys < 10000.0);
    assert!(one_key >= 1.0);
    assert!(two_keys >= 1.0);
}
1719
/// Building a 10-bucket histogram from the values 0..100 is expected to
/// produce 10 buckets covering 100 rows, each holding 9 to 11 rows.
#[test]
fn test_histogram_build_uniform() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    assert_eq!(hist.num_buckets(), 10);
    assert_eq!(hist.total_rows(), 100);

    for b in hist.buckets() {
        assert!((9..=11).contains(&b.frequency));
    }
}
1736
/// Building a 5-bucket histogram from 80 values in 0..80 followed by 20
/// values in 1000..1020 is expected to produce 5 buckets covering 100 rows,
/// each holding 18 to 22 rows.
#[test]
fn test_histogram_build_skewed() {
    let low_cluster = (0..80_i32).map(f64::from);
    let high_cluster = (0..20_i32).map(|i| 1000.0 + f64::from(i));
    let samples: Vec<f64> = low_cluster.chain(high_cluster).collect();

    let hist = EquiDepthHistogram::build(&samples, 5);

    assert_eq!(hist.num_buckets(), 5);
    assert_eq!(hist.total_rows(), 100);

    for b in hist.buckets() {
        assert!((18..=22).contains(&b.frequency));
    }
}
1752
/// An unbounded range over a histogram of 0..100 is expected to have a
/// selectivity of about 1.0.
#[test]
fn test_histogram_range_selectivity_full() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.range_selectivity(None, None);
    assert!((sel - 1.0).abs() < 0.01);
}
1762
/// A range with lower bound 50 and no upper bound over a histogram of
/// 0..100 is expected to have a selectivity strictly between 0.4 and 0.6.
#[test]
fn test_histogram_range_selectivity_half() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.range_selectivity(Some(50.0), None);
    assert!(0.4 < sel && sel < 0.6);
}
1772
/// A range with no lower bound and upper bound 25 over a histogram of
/// 0..100 is expected to have a selectivity strictly between 0.2 and 0.3.
#[test]
fn test_histogram_range_selectivity_quarter() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.range_selectivity(None, Some(25.0));
    assert!(0.2 < sel && sel < 0.3);
}
1782
/// Equality against the value 50 in a histogram of 0..100 is expected to
/// have a selectivity strictly between 0.005 and 0.02.
#[test]
fn test_histogram_equality_selectivity() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let sel = hist.equality_selectivity(50.0);
    assert!(0.005 < sel && sel < 0.02);
}
1792
/// A histogram built from no values is expected to have zero buckets and
/// zero rows, and to report a range selectivity of about 0.33.
#[test]
fn test_histogram_empty() {
    let hist = EquiDepthHistogram::build(&[], 10);

    assert_eq!(hist.num_buckets(), 0);
    assert_eq!(hist.total_rows(), 0);

    let expected = 0.33;
    let sel = hist.range_selectivity(Some(0.0), Some(100.0));
    assert!((sel - expected).abs() < 0.01);
}
1804
/// Checks `HistogramBucket::overlap_fraction` for a bucket spanning
/// `[10, 20)` against full, half, and disjoint query ranges.
#[test]
fn test_histogram_bucket_overlap() {
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);

    // (query lower, query upper, expected overlap fraction)
    let cases: [(f64, f64, f64); 5] = [
        (10.0, 20.0, 1.0), // query covers the whole bucket
        (10.0, 15.0, 0.5), // lower half
        (15.0, 20.0, 0.5), // upper half
        (0.0, 5.0, 0.0),   // entirely below the bucket
        (25.0, 30.0, 0.0), // entirely above the bucket
    ];

    for &(lo, hi, expected) in &cases {
        let fraction = bucket.overlap_fraction(Some(lo), Some(hi));
        assert!((fraction - expected).abs() < 0.01);
    }
}
1824
/// `ColumnStats::from_values` over eight values with five distinct entries
/// is expected to report 5 distinct values, a minimum of about 10, a maximum
/// of about 50, and to attach a histogram.
#[test]
fn test_column_stats_from_values() {
    let samples = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let col = ColumnStats::from_values(samples, 4);

    assert_eq!(col.distinct_count, 5);
    // Each bound must be present and within tolerance of the expected value.
    assert!(col.min_value.map_or(false, |lo| (lo - 10.0).abs() < 0.01));
    assert!(col.max_value.map_or(false, |hi| (hi - 50.0).abs() < 0.01));
    assert!(col.histogram.is_some());
}
1837
/// `ColumnStats::with_histogram` is expected to attach the given histogram,
/// which here was built with 10 buckets from the values 0..100.
#[test]
fn test_column_stats_with_histogram_builder() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    let col = ColumnStats::new(100).with_range(0.0, 99.0).with_histogram(hist);

    assert!(col.histogram.is_some());
    let attached = col.histogram.as_ref().unwrap();
    assert_eq!(attached.num_buckets(), 10);
}
1850
/// Filtering a 1000-row `Person` scan on `n.age > 50`, where `age` carries a
/// 10-bucket histogram built from the values 18..80, is expected to be
/// estimated strictly between 300 and 600 rows.
#[test]
fn test_filter_with_histogram_stats() {
    let mut est = CardinalityEstimator::new();

    let ages: Vec<f64> = (18..80_i32).map(f64::from).collect();
    let age_hist = EquiDepthHistogram::build(&ages, 10);
    let age_col = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(age_hist);
    let person_stats = TableStats::new(1000).with_column("age", age_col);
    est.add_table_stats("Person", person_stats);

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    // Predicate: n.age > 50
    let age_over_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: age_over_50,
        input: Box::new(person_scan),
    });

    let estimated = est.estimate(&plan);
    assert!(300.0 < estimated && estimated < 600.0);
}
1891
/// Filtering a 1000-row `Data` scan on `d.value = 50`, where `value` carries
/// a 10-bucket histogram built from the values 0..100, is expected to be
/// estimated at 1 or more rows and strictly fewer than 50.
#[test]
fn test_filter_equality_with_histogram() {
    let mut est = CardinalityEstimator::new();

    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let value_hist = EquiDepthHistogram::build(&samples, 10);
    let value_col = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(value_hist);
    let data_stats = TableStats::new(1000).with_column("value", value_col);
    est.add_table_stats("Data", data_stats);

    let data_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("d"),
        label: Some(String::from("Data")),
        input: None,
    });
    // Predicate: d.value = 50
    let value_is_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("d"),
            property: String::from("value"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: value_is_50,
        input: Box::new(data_scan),
    });

    let estimated = est.estimate(&plan);
    assert!(estimated >= 1.0 && estimated < 50.0);
}
1928
/// A 2-bucket histogram built from five values starting at 5.0 is expected
/// to report a minimum of exactly 5.0 and to report some maximum.
#[test]
fn test_histogram_min_max() {
    let samples = vec![5.0, 10.0, 15.0, 20.0, 25.0];
    let hist = EquiDepthHistogram::build(&samples, 2);

    assert_eq!(hist.min_value(), Some(5.0));
    // Only the presence of a maximum is checked, not its value.
    assert!(hist.max_value().is_some());
}
1938}