1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19 LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20 VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// A single histogram bucket covering the half-open range
/// `[lower_bound, upper_bound)`.
///
/// A bucket whose two bounds are equal is treated as holding exactly one
/// value, `lower_bound`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket.
    pub upper_bound: f64,
    /// Number of rows counted in this bucket.
    pub frequency: u64,
    /// Number of distinct values counted in this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and counts. No validation is performed.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` when `value` lies in `[lower_bound, upper_bound)`.
    ///
    /// A zero-width bucket (both bounds equal) contains exactly its bound.
    /// Without this special case such a bucket could never match anything,
    /// which disagrees with `overlap_fraction`, where a zero-width bucket is
    /// counted as fully covered by any range that reaches its bounds.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        if self.lower_bound == self.upper_bound {
            return value == self.lower_bound;
        }
        value >= self.lower_bound && value < self.upper_bound
    }

    /// Returns the fraction of this bucket's width covered by the range
    /// `lower..upper`, in `0.0..=1.0`.
    ///
    /// `None` on either side means "unbounded on that side". For a bucket
    /// with zero (or negative) width the result is `1.0` when the clamped
    /// range reaches both bounds and `0.0` otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clamp the query range to the bucket.
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            // No width to divide by: the bucket is either fully in or fully out.
            let covered =
                effective_lower <= self.lower_bound && effective_upper >= self.upper_bound;
            return if covered { 1.0 } else { 0.0 };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}
84
85#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106 buckets: Vec<HistogramBucket>,
108 total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113 #[must_use]
115 pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116 let total_rows = buckets.iter().map(|b| b.frequency).sum();
117 Self {
118 buckets,
119 total_rows,
120 }
121 }
122
123 #[must_use]
132 pub fn build(values: &[f64], num_buckets: usize) -> Self {
133 if values.is_empty() || num_buckets == 0 {
134 return Self {
135 buckets: Vec::new(),
136 total_rows: 0,
137 };
138 }
139
140 let num_buckets = num_buckets.min(values.len());
141 let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142 let mut buckets = Vec::with_capacity(num_buckets);
143
144 let mut start_idx = 0;
145 while start_idx < values.len() {
146 let end_idx = (start_idx + rows_per_bucket).min(values.len());
147 let lower_bound = values[start_idx];
148 let upper_bound = if end_idx < values.len() {
149 values[end_idx]
150 } else {
151 values[end_idx - 1] + 1.0
153 };
154
155 let bucket_values = &values[start_idx..end_idx];
157 let distinct_count = count_distinct(bucket_values);
158
159 buckets.push(HistogramBucket::new(
160 lower_bound,
161 upper_bound,
162 (end_idx - start_idx) as u64,
163 distinct_count,
164 ));
165
166 start_idx = end_idx;
167 }
168
169 Self::new(buckets)
170 }
171
172 #[must_use]
174 pub fn num_buckets(&self) -> usize {
175 self.buckets.len()
176 }
177
178 #[must_use]
180 pub fn total_rows(&self) -> u64 {
181 self.total_rows
182 }
183
184 #[must_use]
186 pub fn buckets(&self) -> &[HistogramBucket] {
187 &self.buckets
188 }
189
190 #[must_use]
199 pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200 if self.buckets.is_empty() || self.total_rows == 0 {
201 return 0.33; }
203
204 let mut matching_rows = 0.0;
205
206 for bucket in &self.buckets {
207 let bucket_lower = bucket.lower_bound;
209 let bucket_upper = bucket.upper_bound;
210
211 if let Some(l) = lower {
213 if bucket_upper <= l {
214 continue;
215 }
216 }
217 if let Some(u) = upper {
218 if bucket_lower >= u {
219 continue;
220 }
221 }
222
223 let overlap = bucket.overlap_fraction(lower, upper);
225 matching_rows += overlap * bucket.frequency as f64;
226 }
227
228 (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
229 }
230
231 #[must_use]
235 pub fn equality_selectivity(&self, value: f64) -> f64 {
236 if self.buckets.is_empty() || self.total_rows == 0 {
237 return 0.01; }
239
240 for bucket in &self.buckets {
242 if bucket.contains(value) {
243 if bucket.distinct_count > 0 {
245 return (bucket.frequency as f64
246 / bucket.distinct_count as f64
247 / self.total_rows as f64)
248 .min(1.0);
249 }
250 }
251 }
252
253 0.001
255 }
256
257 #[must_use]
259 pub fn min_value(&self) -> Option<f64> {
260 self.buckets.first().map(|b| b.lower_bound)
261 }
262
263 #[must_use]
265 pub fn max_value(&self) -> Option<f64> {
266 self.buckets.last().map(|b| b.upper_bound)
267 }
268}
269
/// Counts distinct values in an ascending slice.
///
/// A value is counted as new when it differs from the most recently counted
/// value by more than `f64::EPSILON`. An empty slice yields `0`.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    match sorted_values.split_first() {
        None => 0,
        Some((&first, rest)) => {
            // State: (distinct values so far, last value that was counted).
            let (distinct, _anchor) =
                rest.iter()
                    .fold((1u64, first), |(seen, anchor), &candidate| {
                        if (candidate - anchor).abs() > f64::EPSILON {
                            (seen + 1, candidate)
                        } else {
                            (seen, anchor)
                        }
                    });
            distinct
        }
    }
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
396pub struct CardinalityEstimator {
398 table_stats: HashMap<String, TableStats>,
400 default_row_count: u64,
402 default_selectivity: f64,
404 avg_fanout: f64,
406}
407
408impl CardinalityEstimator {
409 #[must_use]
411 pub fn new() -> Self {
412 Self {
413 table_stats: HashMap::new(),
414 default_row_count: 1000,
415 default_selectivity: 0.1,
416 avg_fanout: 10.0,
417 }
418 }
419
420 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
422 self.table_stats.insert(name.to_string(), stats);
423 }
424
425 pub fn set_avg_fanout(&mut self, fanout: f64) {
427 self.avg_fanout = fanout;
428 }
429
430 #[must_use]
432 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
433 match op {
434 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
435 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
436 LogicalOperator::Project(project) => self.estimate_project(project),
437 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
438 LogicalOperator::Join(join) => self.estimate_join(join),
439 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
440 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
441 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
442 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
443 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
444 LogicalOperator::Return(ret) => self.estimate(&ret.input),
445 LogicalOperator::Empty => 0.0,
446 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
447 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
448 _ => self.default_row_count as f64,
449 }
450 }
451
452 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
454 if let Some(label) = &scan.label {
455 if let Some(stats) = self.table_stats.get(label) {
456 return stats.row_count as f64;
457 }
458 }
459 self.default_row_count as f64
461 }
462
463 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
465 let input_cardinality = self.estimate(&filter.input);
466 let selectivity = self.estimate_selectivity(&filter.predicate);
467 (input_cardinality * selectivity).max(1.0)
468 }
469
470 fn estimate_project(&self, project: &ProjectOp) -> f64 {
472 self.estimate(&project.input)
473 }
474
475 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
477 let input_cardinality = self.estimate(&expand.input);
478
479 let fanout = if expand.edge_type.is_some() {
481 self.avg_fanout * 0.5
483 } else {
484 self.avg_fanout
485 };
486
487 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
489 let min = expand.min_hops as f64;
490 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
491 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
493 } else {
494 fanout
495 };
496
497 (input_cardinality * path_multiplier).max(1.0)
498 }
499
500 fn estimate_join(&self, join: &JoinOp) -> f64 {
502 let left_card = self.estimate(&join.left);
503 let right_card = self.estimate(&join.right);
504
505 match join.join_type {
506 JoinType::Cross => left_card * right_card,
507 JoinType::Inner => {
508 let selectivity = if join.conditions.is_empty() {
510 1.0 } else {
512 0.1_f64.powi(join.conditions.len() as i32)
514 };
515 (left_card * right_card * selectivity).max(1.0)
516 }
517 JoinType::Left => {
518 let inner_card = self.estimate_join(&JoinOp {
520 left: join.left.clone(),
521 right: join.right.clone(),
522 join_type: JoinType::Inner,
523 conditions: join.conditions.clone(),
524 });
525 inner_card.max(left_card)
526 }
527 JoinType::Right => {
528 let inner_card = self.estimate_join(&JoinOp {
530 left: join.left.clone(),
531 right: join.right.clone(),
532 join_type: JoinType::Inner,
533 conditions: join.conditions.clone(),
534 });
535 inner_card.max(right_card)
536 }
537 JoinType::Full => {
538 let inner_card = self.estimate_join(&JoinOp {
540 left: join.left.clone(),
541 right: join.right.clone(),
542 join_type: JoinType::Inner,
543 conditions: join.conditions.clone(),
544 });
545 inner_card.max(left_card.max(right_card))
546 }
547 JoinType::Semi => {
548 (left_card * self.default_selectivity).max(1.0)
550 }
551 JoinType::Anti => {
552 (left_card * (1.0 - self.default_selectivity)).max(1.0)
554 }
555 }
556 }
557
558 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
560 let input_cardinality = self.estimate(&agg.input);
561
562 if agg.group_by.is_empty() {
563 1.0
565 } else {
566 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
569 (input_cardinality / group_reduction).max(1.0)
570 }
571 }
572
573 fn estimate_sort(&self, sort: &SortOp) -> f64 {
575 self.estimate(&sort.input)
576 }
577
578 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
580 let input_cardinality = self.estimate(&distinct.input);
581 (input_cardinality * 0.5).max(1.0)
583 }
584
585 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
587 let input_cardinality = self.estimate(&limit.input);
588 (limit.count as f64).min(input_cardinality)
589 }
590
591 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
593 let input_cardinality = self.estimate(&skip.input);
594 (input_cardinality - skip.count as f64).max(0.0)
595 }
596
597 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
602 let base_k = scan.k as f64;
603
604 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
606 0.7
608 } else {
609 1.0
610 };
611
612 (base_k * selectivity).max(1.0)
613 }
614
615 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
619 let input_cardinality = self.estimate(&join.input);
620 let k = join.k as f64;
621
622 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
624 0.7
625 } else {
626 1.0
627 };
628
629 (input_cardinality * k * selectivity).max(1.0)
630 }
631
632 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
634 match expr {
635 LogicalExpression::Binary { left, op, right } => {
636 self.estimate_binary_selectivity(left, *op, right)
637 }
638 LogicalExpression::Unary { op, operand } => {
639 self.estimate_unary_selectivity(*op, operand)
640 }
641 LogicalExpression::Literal(value) => {
642 if let grafeo_common::types::Value::Bool(b) = value {
644 if *b { 1.0 } else { 0.0 }
645 } else {
646 self.default_selectivity
647 }
648 }
649 _ => self.default_selectivity,
650 }
651 }
652
653 fn estimate_binary_selectivity(
655 &self,
656 left: &LogicalExpression,
657 op: BinaryOp,
658 right: &LogicalExpression,
659 ) -> f64 {
660 match op {
661 BinaryOp::Eq => {
663 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
664 return selectivity;
665 }
666 0.01
667 }
668 BinaryOp::Ne => 0.99,
670 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
672 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
673 return selectivity;
674 }
675 0.33
676 }
677 BinaryOp::And => {
679 let left_sel = self.estimate_selectivity(left);
680 let right_sel = self.estimate_selectivity(right);
681 left_sel * right_sel
683 }
684 BinaryOp::Or => {
685 let left_sel = self.estimate_selectivity(left);
686 let right_sel = self.estimate_selectivity(right);
687 (left_sel + right_sel - left_sel * right_sel).min(1.0)
690 }
691 BinaryOp::StartsWith => 0.1,
693 BinaryOp::EndsWith => 0.1,
694 BinaryOp::Contains => 0.1,
695 BinaryOp::Like => 0.1,
696 BinaryOp::In => 0.1,
698 _ => self.default_selectivity,
700 }
701 }
702
703 fn try_equality_selectivity(
705 &self,
706 left: &LogicalExpression,
707 right: &LogicalExpression,
708 ) -> Option<f64> {
709 let (label, column, value) = self.extract_column_and_value(left, right)?;
711
712 let stats = self.get_column_stats(&label, &column)?;
714
715 if let Some(ref histogram) = stats.histogram {
717 return Some(histogram.equality_selectivity(value));
718 }
719
720 if stats.distinct_count > 0 {
722 return Some(1.0 / stats.distinct_count as f64);
723 }
724
725 None
726 }
727
728 fn try_range_selectivity(
730 &self,
731 left: &LogicalExpression,
732 op: BinaryOp,
733 right: &LogicalExpression,
734 ) -> Option<f64> {
735 let (label, column, value) = self.extract_column_and_value(left, right)?;
737
738 let stats = self.get_column_stats(&label, &column)?;
740
741 let (lower, upper) = match op {
743 BinaryOp::Lt => (None, Some(value)),
744 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
745 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
746 BinaryOp::Ge => (Some(value), None),
747 _ => return None,
748 };
749
750 if let Some(ref histogram) = stats.histogram {
752 return Some(histogram.range_selectivity(lower, upper));
753 }
754
755 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
757 let range = max - min;
758 if range <= 0.0 {
759 return Some(1.0);
760 }
761
762 let effective_lower = lower.unwrap_or(min).max(min);
763 let effective_upper = upper.unwrap_or(max).min(max);
764 let overlap = (effective_upper - effective_lower).max(0.0);
765 return Some((overlap / range).min(1.0).max(0.0));
766 }
767
768 None
769 }
770
771 fn extract_column_and_value(
776 &self,
777 left: &LogicalExpression,
778 right: &LogicalExpression,
779 ) -> Option<(String, String, f64)> {
780 if let Some(result) = self.try_extract_property_literal(left, right) {
782 return Some(result);
783 }
784
785 self.try_extract_property_literal(right, left)
787 }
788
789 fn try_extract_property_literal(
791 &self,
792 property_expr: &LogicalExpression,
793 literal_expr: &LogicalExpression,
794 ) -> Option<(String, String, f64)> {
795 let (variable, property) = match property_expr {
797 LogicalExpression::Property { variable, property } => {
798 (variable.clone(), property.clone())
799 }
800 _ => return None,
801 };
802
803 let value = match literal_expr {
805 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
806 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
807 _ => return None,
808 };
809
810 for label in self.table_stats.keys() {
814 if let Some(stats) = self.table_stats.get(label) {
815 if stats.columns.contains_key(&property) {
816 return Some((label.clone(), property, value));
817 }
818 }
819 }
820
821 Some((variable, property, value))
823 }
824
825 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
827 match op {
828 UnaryOp::Not => 1.0 - self.default_selectivity,
829 UnaryOp::IsNull => 0.05, UnaryOp::IsNotNull => 0.95,
831 UnaryOp::Neg => 1.0, }
833 }
834
835 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
837 self.table_stats.get(label)?.columns.get(column)
838 }
839
840 #[allow(dead_code)]
842 fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
843 if let Some(stats) = self.get_column_stats(label, column) {
844 if stats.distinct_count > 0 {
845 return 1.0 / stats.distinct_count as f64;
846 }
847 }
848 0.01 }
850
851 #[allow(dead_code)]
853 fn estimate_range_with_stats(
854 &self,
855 label: &str,
856 column: &str,
857 lower: Option<f64>,
858 upper: Option<f64>,
859 ) -> f64 {
860 if let Some(stats) = self.get_column_stats(label, column) {
861 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
862 let range = max - min;
863 if range <= 0.0 {
864 return 1.0;
865 }
866
867 let effective_lower = lower.unwrap_or(min).max(min);
868 let effective_upper = upper.unwrap_or(max).min(max);
869
870 let overlap = (effective_upper - effective_lower).max(0.0);
871 return (overlap / range).min(1.0).max(0.0);
872 }
873 }
874 0.33 }
876}
877
878impl Default for CardinalityEstimator {
879 fn default() -> Self {
880 Self::new()
881 }
882}
883
884#[cfg(test)]
885mod tests {
886 use super::*;
887 use crate::query::plan::{
888 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
889 Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
890 };
891 use grafeo_common::types::Value;
892
893 #[test]
894 fn test_node_scan_with_stats() {
895 let mut estimator = CardinalityEstimator::new();
896 estimator.add_table_stats("Person", TableStats::new(5000));
897
898 let scan = LogicalOperator::NodeScan(NodeScanOp {
899 variable: "n".to_string(),
900 label: Some("Person".to_string()),
901 input: None,
902 });
903
904 let cardinality = estimator.estimate(&scan);
905 assert!((cardinality - 5000.0).abs() < 0.001);
906 }
907
908 #[test]
909 fn test_filter_reduces_cardinality() {
910 let mut estimator = CardinalityEstimator::new();
911 estimator.add_table_stats("Person", TableStats::new(1000));
912
913 let filter = LogicalOperator::Filter(FilterOp {
914 predicate: LogicalExpression::Binary {
915 left: Box::new(LogicalExpression::Property {
916 variable: "n".to_string(),
917 property: "age".to_string(),
918 }),
919 op: BinaryOp::Eq,
920 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
921 },
922 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
923 variable: "n".to_string(),
924 label: Some("Person".to_string()),
925 input: None,
926 })),
927 });
928
929 let cardinality = estimator.estimate(&filter);
930 assert!(cardinality < 1000.0);
932 assert!(cardinality >= 1.0);
933 }
934
935 #[test]
936 fn test_join_cardinality() {
937 let mut estimator = CardinalityEstimator::new();
938 estimator.add_table_stats("Person", TableStats::new(1000));
939 estimator.add_table_stats("Company", TableStats::new(100));
940
941 let join = LogicalOperator::Join(JoinOp {
942 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
943 variable: "p".to_string(),
944 label: Some("Person".to_string()),
945 input: None,
946 })),
947 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
948 variable: "c".to_string(),
949 label: Some("Company".to_string()),
950 input: None,
951 })),
952 join_type: JoinType::Inner,
953 conditions: vec![JoinCondition {
954 left: LogicalExpression::Property {
955 variable: "p".to_string(),
956 property: "company_id".to_string(),
957 },
958 right: LogicalExpression::Property {
959 variable: "c".to_string(),
960 property: "id".to_string(),
961 },
962 }],
963 });
964
965 let cardinality = estimator.estimate(&join);
966 assert!(cardinality < 1000.0 * 100.0);
968 }
969
970 #[test]
971 fn test_limit_caps_cardinality() {
972 let mut estimator = CardinalityEstimator::new();
973 estimator.add_table_stats("Person", TableStats::new(1000));
974
975 let limit = LogicalOperator::Limit(LimitOp {
976 count: 10,
977 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
978 variable: "n".to_string(),
979 label: Some("Person".to_string()),
980 input: None,
981 })),
982 });
983
984 let cardinality = estimator.estimate(&limit);
985 assert!((cardinality - 10.0).abs() < 0.001);
986 }
987
988 #[test]
989 fn test_aggregate_reduces_cardinality() {
990 let mut estimator = CardinalityEstimator::new();
991 estimator.add_table_stats("Person", TableStats::new(1000));
992
993 let global_agg = LogicalOperator::Aggregate(AggregateOp {
995 group_by: vec![],
996 aggregates: vec![],
997 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
998 variable: "n".to_string(),
999 label: Some("Person".to_string()),
1000 input: None,
1001 })),
1002 having: None,
1003 });
1004
1005 let cardinality = estimator.estimate(&global_agg);
1006 assert!((cardinality - 1.0).abs() < 0.001);
1007
1008 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1010 group_by: vec![LogicalExpression::Property {
1011 variable: "n".to_string(),
1012 property: "city".to_string(),
1013 }],
1014 aggregates: vec![],
1015 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1016 variable: "n".to_string(),
1017 label: Some("Person".to_string()),
1018 input: None,
1019 })),
1020 having: None,
1021 });
1022
1023 let cardinality = estimator.estimate(&group_agg);
1024 assert!(cardinality < 1000.0);
1026 }
1027
1028 #[test]
1029 fn test_node_scan_without_stats() {
1030 let estimator = CardinalityEstimator::new();
1031
1032 let scan = LogicalOperator::NodeScan(NodeScanOp {
1033 variable: "n".to_string(),
1034 label: Some("Unknown".to_string()),
1035 input: None,
1036 });
1037
1038 let cardinality = estimator.estimate(&scan);
1039 assert!((cardinality - 1000.0).abs() < 0.001);
1041 }
1042
1043 #[test]
1044 fn test_node_scan_no_label() {
1045 let estimator = CardinalityEstimator::new();
1046
1047 let scan = LogicalOperator::NodeScan(NodeScanOp {
1048 variable: "n".to_string(),
1049 label: None,
1050 input: None,
1051 });
1052
1053 let cardinality = estimator.estimate(&scan);
1054 assert!((cardinality - 1000.0).abs() < 0.001);
1056 }
1057
1058 #[test]
1059 fn test_filter_inequality_selectivity() {
1060 let mut estimator = CardinalityEstimator::new();
1061 estimator.add_table_stats("Person", TableStats::new(1000));
1062
1063 let filter = LogicalOperator::Filter(FilterOp {
1064 predicate: LogicalExpression::Binary {
1065 left: Box::new(LogicalExpression::Property {
1066 variable: "n".to_string(),
1067 property: "age".to_string(),
1068 }),
1069 op: BinaryOp::Ne,
1070 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1071 },
1072 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1073 variable: "n".to_string(),
1074 label: Some("Person".to_string()),
1075 input: None,
1076 })),
1077 });
1078
1079 let cardinality = estimator.estimate(&filter);
1080 assert!(cardinality > 900.0);
1082 }
1083
1084 #[test]
1085 fn test_filter_range_selectivity() {
1086 let mut estimator = CardinalityEstimator::new();
1087 estimator.add_table_stats("Person", TableStats::new(1000));
1088
1089 let filter = LogicalOperator::Filter(FilterOp {
1090 predicate: LogicalExpression::Binary {
1091 left: Box::new(LogicalExpression::Property {
1092 variable: "n".to_string(),
1093 property: "age".to_string(),
1094 }),
1095 op: BinaryOp::Gt,
1096 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1097 },
1098 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1099 variable: "n".to_string(),
1100 label: Some("Person".to_string()),
1101 input: None,
1102 })),
1103 });
1104
1105 let cardinality = estimator.estimate(&filter);
1106 assert!(cardinality < 500.0);
1108 assert!(cardinality > 100.0);
1109 }
1110
1111 #[test]
1112 fn test_filter_and_selectivity() {
1113 let mut estimator = CardinalityEstimator::new();
1114 estimator.add_table_stats("Person", TableStats::new(1000));
1115
1116 let filter = LogicalOperator::Filter(FilterOp {
1119 predicate: LogicalExpression::Binary {
1120 left: Box::new(LogicalExpression::Binary {
1121 left: Box::new(LogicalExpression::Property {
1122 variable: "n".to_string(),
1123 property: "city".to_string(),
1124 }),
1125 op: BinaryOp::Eq,
1126 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1127 }),
1128 op: BinaryOp::And,
1129 right: Box::new(LogicalExpression::Binary {
1130 left: Box::new(LogicalExpression::Property {
1131 variable: "n".to_string(),
1132 property: "age".to_string(),
1133 }),
1134 op: BinaryOp::Eq,
1135 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1136 }),
1137 },
1138 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1139 variable: "n".to_string(),
1140 label: Some("Person".to_string()),
1141 input: None,
1142 })),
1143 });
1144
1145 let cardinality = estimator.estimate(&filter);
1146 assert!(cardinality < 100.0);
1149 assert!(cardinality >= 1.0);
1150 }
1151
1152 #[test]
1153 fn test_filter_or_selectivity() {
1154 let mut estimator = CardinalityEstimator::new();
1155 estimator.add_table_stats("Person", TableStats::new(1000));
1156
1157 let filter = LogicalOperator::Filter(FilterOp {
1161 predicate: LogicalExpression::Binary {
1162 left: Box::new(LogicalExpression::Binary {
1163 left: Box::new(LogicalExpression::Property {
1164 variable: "n".to_string(),
1165 property: "city".to_string(),
1166 }),
1167 op: BinaryOp::Eq,
1168 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1169 }),
1170 op: BinaryOp::Or,
1171 right: Box::new(LogicalExpression::Binary {
1172 left: Box::new(LogicalExpression::Property {
1173 variable: "n".to_string(),
1174 property: "city".to_string(),
1175 }),
1176 op: BinaryOp::Eq,
1177 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1178 }),
1179 },
1180 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1181 variable: "n".to_string(),
1182 label: Some("Person".to_string()),
1183 input: None,
1184 })),
1185 });
1186
1187 let cardinality = estimator.estimate(&filter);
1188 assert!(cardinality < 100.0);
1190 assert!(cardinality >= 1.0);
1191 }
1192
1193 #[test]
1194 fn test_filter_literal_true() {
1195 let mut estimator = CardinalityEstimator::new();
1196 estimator.add_table_stats("Person", TableStats::new(1000));
1197
1198 let filter = LogicalOperator::Filter(FilterOp {
1199 predicate: LogicalExpression::Literal(Value::Bool(true)),
1200 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1201 variable: "n".to_string(),
1202 label: Some("Person".to_string()),
1203 input: None,
1204 })),
1205 });
1206
1207 let cardinality = estimator.estimate(&filter);
1208 assert!((cardinality - 1000.0).abs() < 0.001);
1210 }
1211
1212 #[test]
1213 fn test_filter_literal_false() {
1214 let mut estimator = CardinalityEstimator::new();
1215 estimator.add_table_stats("Person", TableStats::new(1000));
1216
1217 let filter = LogicalOperator::Filter(FilterOp {
1218 predicate: LogicalExpression::Literal(Value::Bool(false)),
1219 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1220 variable: "n".to_string(),
1221 label: Some("Person".to_string()),
1222 input: None,
1223 })),
1224 });
1225
1226 let cardinality = estimator.estimate(&filter);
1227 assert!((cardinality - 1.0).abs() < 0.001);
1229 }
1230
1231 #[test]
1232 fn test_unary_not_selectivity() {
1233 let mut estimator = CardinalityEstimator::new();
1234 estimator.add_table_stats("Person", TableStats::new(1000));
1235
1236 let filter = LogicalOperator::Filter(FilterOp {
1237 predicate: LogicalExpression::Unary {
1238 op: UnaryOp::Not,
1239 operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1240 },
1241 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1242 variable: "n".to_string(),
1243 label: Some("Person".to_string()),
1244 input: None,
1245 })),
1246 });
1247
1248 let cardinality = estimator.estimate(&filter);
1249 assert!(cardinality < 1000.0);
1251 }
1252
1253 #[test]
1254 fn test_unary_is_null_selectivity() {
1255 let mut estimator = CardinalityEstimator::new();
1256 estimator.add_table_stats("Person", TableStats::new(1000));
1257
1258 let filter = LogicalOperator::Filter(FilterOp {
1259 predicate: LogicalExpression::Unary {
1260 op: UnaryOp::IsNull,
1261 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1262 },
1263 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1264 variable: "n".to_string(),
1265 label: Some("Person".to_string()),
1266 input: None,
1267 })),
1268 });
1269
1270 let cardinality = estimator.estimate(&filter);
1271 assert!(cardinality < 100.0);
1273 }
1274
1275 #[test]
1276 fn test_expand_cardinality() {
1277 let mut estimator = CardinalityEstimator::new();
1278 estimator.add_table_stats("Person", TableStats::new(100));
1279
1280 let expand = LogicalOperator::Expand(ExpandOp {
1281 from_variable: "a".to_string(),
1282 to_variable: "b".to_string(),
1283 edge_variable: None,
1284 direction: ExpandDirection::Outgoing,
1285 edge_type: None,
1286 min_hops: 1,
1287 max_hops: Some(1),
1288 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1289 variable: "a".to_string(),
1290 label: Some("Person".to_string()),
1291 input: None,
1292 })),
1293 path_alias: None,
1294 });
1295
1296 let cardinality = estimator.estimate(&expand);
1297 assert!(cardinality > 100.0);
1299 }
1300
1301 #[test]
1302 fn test_expand_with_edge_type_filter() {
1303 let mut estimator = CardinalityEstimator::new();
1304 estimator.add_table_stats("Person", TableStats::new(100));
1305
1306 let expand = LogicalOperator::Expand(ExpandOp {
1307 from_variable: "a".to_string(),
1308 to_variable: "b".to_string(),
1309 edge_variable: None,
1310 direction: ExpandDirection::Outgoing,
1311 edge_type: Some("KNOWS".to_string()),
1312 min_hops: 1,
1313 max_hops: Some(1),
1314 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1315 variable: "a".to_string(),
1316 label: Some("Person".to_string()),
1317 input: None,
1318 })),
1319 path_alias: None,
1320 });
1321
1322 let cardinality = estimator.estimate(&expand);
1323 assert!(cardinality > 100.0);
1325 }
1326
1327 #[test]
1328 fn test_expand_variable_length() {
1329 let mut estimator = CardinalityEstimator::new();
1330 estimator.add_table_stats("Person", TableStats::new(100));
1331
1332 let expand = LogicalOperator::Expand(ExpandOp {
1333 from_variable: "a".to_string(),
1334 to_variable: "b".to_string(),
1335 edge_variable: None,
1336 direction: ExpandDirection::Outgoing,
1337 edge_type: None,
1338 min_hops: 1,
1339 max_hops: Some(3),
1340 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1341 variable: "a".to_string(),
1342 label: Some("Person".to_string()),
1343 input: None,
1344 })),
1345 path_alias: None,
1346 });
1347
1348 let cardinality = estimator.estimate(&expand);
1349 assert!(cardinality > 500.0);
1351 }
1352
/// A cross join of `Person` (100 rows) with `Company` (50 rows) and no
/// conditions is expected to estimate the full product of 5000 rows.
#[test]
fn test_join_cross_product() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));
    estimator.add_table_stats("Company", TableStats::new(50));

    // Builds a boxed leaf scan binding `variable` to nodes labelled `label`.
    let boxed_scan = |variable: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };

    let join = LogicalOperator::Join(JoinOp {
        left: boxed_scan("p", "Person"),
        right: boxed_scan("c", "Company"),
        join_type: JoinType::Cross,
        conditions: vec![],
    });

    let cardinality = estimator.estimate(&join);
    assert!((cardinality - 5000.0).abs() < 0.001);
}
1378
/// A left join of `Person` (1000 rows) with `Company` (10 rows) on `p = c`
/// is expected to estimate at least the 1000 rows of the left side.
#[test]
fn test_join_left_outer() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(10));

    // Builds a boxed leaf scan binding `variable` to nodes labelled `label`.
    let boxed_scan = |variable: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };

    // Single join condition relating variable `p` to variable `c`.
    let p_matches_c = JoinCondition {
        left: LogicalExpression::Variable("p".to_string()),
        right: LogicalExpression::Variable("c".to_string()),
    };

    let join = LogicalOperator::Join(JoinOp {
        left: boxed_scan("p", "Person"),
        right: boxed_scan("c", "Company"),
        join_type: JoinType::Left,
        conditions: vec![p_matches_c],
    });

    let cardinality = estimator.estimate(&join);
    assert!(cardinality >= 1000.0);
}
1407
/// A semi join of `Person` (1000 rows) against `Company` (100 rows) with no
/// conditions is expected to estimate at most 1000 rows.
#[test]
fn test_join_semi() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(100));

    // Builds a boxed leaf scan binding `variable` to nodes labelled `label`.
    let boxed_scan = |variable: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };

    let join = LogicalOperator::Join(JoinOp {
        left: boxed_scan("p", "Person"),
        right: boxed_scan("c", "Company"),
        join_type: JoinType::Semi,
        conditions: vec![],
    });

    let cardinality = estimator.estimate(&join);
    assert!(cardinality <= 1000.0);
}
1433
/// An anti join of `Person` (1000 rows) against `Company` (100 rows) with no
/// conditions is expected to estimate between 1 and 1000 rows inclusive.
#[test]
fn test_join_anti() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));
    estimator.add_table_stats("Company", TableStats::new(100));

    // Builds a boxed leaf scan binding `variable` to nodes labelled `label`.
    let boxed_scan = |variable: &str, label: &str| {
        Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        }))
    };

    let join = LogicalOperator::Join(JoinOp {
        left: boxed_scan("p", "Person"),
        right: boxed_scan("c", "Company"),
        join_type: JoinType::Anti,
        conditions: vec![],
    });

    let cardinality = estimator.estimate(&join);
    assert!(cardinality <= 1000.0);
    assert!(cardinality >= 1.0);
}
1460
/// Projecting variable `n` over a 1000-row `Person` scan is expected to
/// leave the estimate at 1000 rows.
#[test]
fn test_project_preserves_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    // One un-aliased projection of the scanned variable.
    let project_n = Projection {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });

    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![project_n],
        input: Box::new(person_scan),
    });

    let cardinality = estimator.estimate(&project);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1481
/// Sorting a 1000-row `Person` scan by `n` ascending is expected to leave
/// the estimate at 1000 rows.
#[test]
fn test_sort_preserves_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    // Single ascending sort key on the scanned variable.
    let by_n_ascending = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });

    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![by_n_ascending],
        input: Box::new(person_scan),
    });

    let cardinality = estimator.estimate(&sort);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1502
/// `Distinct` with no explicit column list over a 1000-row `Person` scan is
/// expected to estimate 500 rows.
#[test]
fn test_distinct_reduces_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    let cardinality = estimator.estimate(&distinct);
    assert!((cardinality - 500.0).abs() < 0.001);
}
1521
/// Skipping 100 rows of a 1000-row `Person` scan is expected to estimate
/// the remaining 900 rows.
#[test]
fn test_skip_reduces_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 100,
        input: Box::new(person_scan),
    });

    let cardinality = estimator.estimate(&skip);
    assert!((cardinality - 900.0).abs() < 0.001);
}
1539
/// A non-distinct `Return` of `n` over a 1000-row `Person` scan is expected
/// to leave the estimate at 1000 rows.
#[test]
fn test_return_preserves_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(1000));

    // One un-aliased return item for the scanned variable.
    let return_n = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });

    let ret = LogicalOperator::Return(ReturnOp {
        items: vec![return_n],
        distinct: false,
        input: Box::new(person_scan),
    });

    let cardinality = estimator.estimate(&ret);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1561
/// The `Empty` operator is expected to estimate (approximately) zero rows,
/// even on an estimator with no statistics registered.
#[test]
fn test_empty_cardinality() {
    let empty_plan = LogicalOperator::Empty;
    let cardinality = CardinalityEstimator::new().estimate(&empty_plan);
    assert!((cardinality).abs() < 0.001);
}
1568
/// The `TableStats` / `ColumnStats` builder methods store the row count,
/// distinct count, null count and value range exactly as supplied.
#[test]
fn test_table_stats_with_column() {
    // 50 distinct values, 10 nulls, values ranging over 0.0..=100.0.
    let age_stats = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let stats = TableStats::new(1000).with_column("age", age_stats);

    assert_eq!(stats.row_count, 1000);

    let col = stats.columns.get("age").unwrap();
    assert_eq!(col.distinct_count, 50);
    assert_eq!(col.null_count, 10);
    assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
    assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
}
1583
/// A `Default`-constructed estimator with no table statistics is expected
/// to estimate 1000 rows for an unlabelled node scan.
#[test]
fn test_estimator_default() {
    let unlabelled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });

    let estimator = CardinalityEstimator::default();
    let cardinality = estimator.estimate(&unlabelled_scan);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1595
/// With the average fanout set to 5.0, a single-hop expansion from a
/// 100-row `Person` scan is expected to estimate 500 rows.
#[test]
fn test_set_avg_fanout() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(100));
    estimator.set_avg_fanout(5.0);

    // Source side of the traversal: every `Person` node bound to `a`.
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });

    // Exactly one outgoing hop with no edge-type filter.
    let single_hop = ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: None,
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
    };

    let cardinality = estimator.estimate(&LogicalOperator::Expand(single_hop));
    assert!((cardinality - 500.0).abs() < 0.001);
}
1622
/// Grouping a 10000-row `Person` scan by one key (`city`) or by two keys
/// (`city`, `country`) is expected to estimate at least 1 and fewer than
/// 10000 rows in both cases.
///
/// Note: the two estimates are only bounded individually; the test does not
/// compare them against each other.
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Person", TableStats::new(10000));

    // Property access `n.<name>` used as a grouping key.
    let n_property = |name: &str| LogicalExpression::Property {
        variable: "n".to_string(),
        property: name.to_string(),
    };

    // Aggregate with the given grouping keys, no aggregate functions and no
    // HAVING clause, fed by a fresh `Person` scan bound to `n`.
    let grouped_by = |keys: Vec<LogicalExpression>| {
        LogicalOperator::Aggregate(AggregateOp {
            group_by: keys,
            aggregates: vec![],
            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                variable: "n".to_string(),
                label: Some("Person".to_string()),
                input: None,
            })),
            having: None,
        })
    };

    let single_group = grouped_by(vec![n_property("city")]);
    let multi_group = grouped_by(vec![n_property("city"), n_property("country")]);

    let single_card = estimator.estimate(&single_group);
    let multi_card = estimator.estimate(&multi_group);

    assert!(single_card < 10000.0);
    assert!(multi_card < 10000.0);
    assert!(single_card >= 1.0);
    assert!(multi_card >= 1.0);
}
1675
/// Building 10 buckets over the values 0.0, 1.0, ..., 99.0 must produce 10
/// buckets covering all 100 rows, each holding between 9 and 11 rows.
#[test]
fn test_histogram_build_uniform() {
    let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
    let histogram = EquiDepthHistogram::build(&values, 10);

    assert_eq!(histogram.num_buckets(), 10);
    assert_eq!(histogram.total_rows(), 100);

    for bucket in histogram.buckets() {
        // Report the offending bucket so a failure is diagnosable without
        // re-running under a debugger.
        assert!(
            (9..=11).contains(&bucket.frequency),
            "bucket frequency outside 9..=11: {:?}",
            bucket
        );
    }
}
1692
/// Building 5 buckets over skewed input (80 values 0.0..=79.0 followed by 20
/// values 1000.0..=1019.0) must still produce 5 buckets covering all 100
/// rows, each holding between 18 and 22 rows.
#[test]
fn test_histogram_build_skewed() {
    let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
    values.extend((0..20).map(|i| 1000.0 + i as f64));
    let histogram = EquiDepthHistogram::build(&values, 5);

    assert_eq!(histogram.num_buckets(), 5);
    assert_eq!(histogram.total_rows(), 100);

    for bucket in histogram.buckets() {
        // Report the offending bucket so a failure is diagnosable without
        // re-running under a debugger.
        assert!(
            (18..=22).contains(&bucket.frequency),
            "bucket frequency outside 18..=22: {:?}",
            bucket
        );
    }
}
1708
/// An unbounded range over a histogram of 0.0..=99.0 is expected to select
/// (approximately) everything.
#[test]
fn test_histogram_range_selectivity_full() {
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();

    // No lower and no upper bound.
    let selectivity = EquiDepthHistogram::build(&samples, 10).range_selectivity(None, None);
    assert!((selectivity - 1.0).abs() < 0.01);
}
1718
/// A lower bound of 50.0 over a histogram of 0.0..=99.0 is expected to
/// select between 40% and 60% of rows (exclusive).
#[test]
fn test_histogram_range_selectivity_half() {
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();

    // Lower bound only; the upper end is left open.
    let selectivity =
        EquiDepthHistogram::build(&samples, 10).range_selectivity(Some(50.0), None);
    assert!(selectivity > 0.4 && selectivity < 0.6);
}
1728
/// An upper bound of 25.0 over a histogram of 0.0..=99.0 is expected to
/// select between 20% and 30% of rows (exclusive).
#[test]
fn test_histogram_range_selectivity_quarter() {
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();

    // Upper bound only; the lower end is left open.
    let selectivity =
        EquiDepthHistogram::build(&samples, 10).range_selectivity(None, Some(25.0));
    assert!(selectivity > 0.2 && selectivity < 0.3);
}
1738
/// Equality against 50.0 over a histogram of the 100 values 0.0..=99.0 is
/// expected to select between 0.5% and 2% of rows (exclusive).
#[test]
fn test_histogram_equality_selectivity() {
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();

    let selectivity = EquiDepthHistogram::build(&samples, 10).equality_selectivity(50.0);
    assert!(selectivity > 0.005 && selectivity < 0.02);
}
1748
/// A histogram built from no values has zero buckets and zero rows, and is
/// expected to report a range selectivity of about 0.33.
#[test]
fn test_histogram_empty() {
    let no_values: &[f64] = &[];
    let histogram = EquiDepthHistogram::build(no_values, 10);

    assert_eq!(histogram.num_buckets(), 0);
    assert_eq!(histogram.total_rows(), 0);

    // With nothing to consult, the estimate for a bounded range is ~0.33.
    let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
    assert!((selectivity - 0.33).abs() < 0.01);
}
1760
/// `HistogramBucket::overlap_fraction` on the bucket `[10, 20)`: full
/// coverage gives 1.0, either half gives 0.5, and ranges entirely below or
/// above the bucket give 0.0.
#[test]
fn test_histogram_bucket_overlap() {
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);

    // Query range equal to the bucket.
    let whole = bucket.overlap_fraction(Some(10.0), Some(20.0));
    assert!((whole - 1.0).abs() < 0.01);

    // Lower half of the bucket.
    let lower_half = bucket.overlap_fraction(Some(10.0), Some(15.0));
    assert!((lower_half - 0.5).abs() < 0.01);

    // Upper half of the bucket.
    let upper_half = bucket.overlap_fraction(Some(15.0), Some(20.0));
    assert!((upper_half - 0.5).abs() < 0.01);

    // Query range entirely below the bucket.
    let below = bucket.overlap_fraction(Some(0.0), Some(5.0));
    assert!(below.abs() < 0.01);

    // Query range entirely above the bucket.
    let above = bucket.overlap_fraction(Some(25.0), Some(30.0));
    assert!(above.abs() < 0.01);
}
1780
/// `ColumnStats::from_values` over eight values (five of them distinct,
/// ranging 10.0..=50.0) with 4 requested buckets must report 5 distinct
/// values, a minimum of 10.0, a maximum of 50.0, and an attached histogram.
#[test]
fn test_column_stats_from_values() {
    let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let stats = ColumnStats::from_values(values, 4);

    assert_eq!(stats.distinct_count, 5);

    // `expect` both checks presence and yields the value, with a message
    // that names the missing field on failure.
    let min = stats
        .min_value
        .expect("min_value should be set for non-empty input");
    assert!((min - 10.0).abs() < 0.01);

    let max = stats
        .max_value
        .expect("max_value should be set for non-empty input");
    assert!((max - 50.0).abs() < 0.01);

    assert!(stats.histogram.is_some());
}
1793
/// `ColumnStats::with_histogram` stores the supplied histogram, which still
/// reports its 10 buckets afterwards.
#[test]
fn test_column_stats_with_histogram_builder() {
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();
    let ten_buckets = EquiDepthHistogram::build(&samples, 10);

    let stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(ten_buckets);

    assert!(stats.histogram.is_some());
    assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
}
1806
/// With an `age` histogram over 18.0..=79.0 attached to a 1000-row `Person`
/// table, the filter `n.age > 50` is expected to estimate between 300 and
/// 600 rows (exclusive).
#[test]
fn test_filter_with_histogram_stats() {
    // Column statistics: 62 distinct ages, range 18..=79, 10-bucket histogram.
    let age_values: Vec<f64> = (18..80).map(|age| age as f64).collect();
    let age_histogram = EquiDepthHistogram::build(&age_values, 10);
    let age_stats = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(age_histogram);

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats(
        "Person",
        TableStats::new(1000).with_column("age", age_stats),
    );

    // Predicate: n.age > 50
    let age_over_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });

    let filter = LogicalOperator::Filter(FilterOp {
        predicate: age_over_50,
        input: Box::new(person_scan),
    });

    let cardinality = estimator.estimate(&filter);

    assert!(cardinality > 300.0 && cardinality < 600.0);
}
1847
/// With a `value` histogram over 0.0..=99.0 attached to a 1000-row `Data`
/// table, the filter `d.value = 50` is expected to estimate at least 1 and
/// fewer than 50 rows.
#[test]
fn test_filter_equality_with_histogram() {
    // Column statistics: 100 distinct values, range 0..=99, 10-bucket histogram.
    let samples: Vec<f64> = (0..100).map(|n| n as f64).collect();
    let value_histogram = EquiDepthHistogram::build(&samples, 10);
    let value_stats = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(value_histogram);

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats(
        "Data",
        TableStats::new(1000).with_column("value", value_stats),
    );

    // Predicate: d.value = 50
    let value_is_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("d"),
            property: String::from("value"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let data_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("d"),
        label: Some(String::from("Data")),
        input: None,
    });

    let filter = LogicalOperator::Filter(FilterOp {
        predicate: value_is_50,
        input: Box::new(data_scan),
    });

    let cardinality = estimator.estimate(&filter);

    assert!(cardinality >= 1.0 && cardinality < 50.0);
}
1884
/// A 2-bucket histogram over 5.0, 10.0, 15.0, 20.0, 25.0 reports a minimum
/// of exactly 5.0 and reports some maximum (its value is not checked).
#[test]
fn test_histogram_min_max() {
    let samples = [5.0, 10.0, 15.0, 20.0, 25.0];
    let histogram = EquiDepthHistogram::build(&samples, 2);

    assert_eq!(histogram.min_value(), Some(5.0));
    assert!(histogram.max_value().is_some());
}
1894}