1use crate::query::plan::{
6 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7 LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
8};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct TableStats {
14 pub row_count: u64,
16 pub columns: HashMap<String, ColumnStats>,
18}
19
20impl TableStats {
21 #[must_use]
23 pub fn new(row_count: u64) -> Self {
24 Self {
25 row_count,
26 columns: HashMap::new(),
27 }
28 }
29
30 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
32 self.columns.insert(name.to_string(), stats);
33 self
34 }
35}
36
/// Statistics describing the values stored in a single column.
///
/// Values are compared field-by-field via the derived `PartialEq`.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnStats {
    /// Number of distinct values in the column.
    pub distinct_count: u64,
    /// Number of null values in the column.
    pub null_count: u64,
    /// Smallest value as an `f64`, when known.
    pub min_value: Option<f64>,
    /// Largest value as an `f64`, when known.
    pub max_value: Option<f64>,
}
49
50impl ColumnStats {
51 #[must_use]
53 pub fn new(distinct_count: u64) -> Self {
54 Self {
55 distinct_count,
56 null_count: 0,
57 min_value: None,
58 max_value: None,
59 }
60 }
61
62 #[must_use]
64 pub fn with_nulls(mut self, null_count: u64) -> Self {
65 self.null_count = null_count;
66 self
67 }
68
69 #[must_use]
71 pub fn with_range(mut self, min: f64, max: f64) -> Self {
72 self.min_value = Some(min);
73 self.max_value = Some(max);
74 self
75 }
76}
77
78pub struct CardinalityEstimator {
80 table_stats: HashMap<String, TableStats>,
82 default_row_count: u64,
84 default_selectivity: f64,
86 avg_fanout: f64,
88}
89
90impl CardinalityEstimator {
91 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 table_stats: HashMap::new(),
96 default_row_count: 1000,
97 default_selectivity: 0.1,
98 avg_fanout: 10.0,
99 }
100 }
101
102 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
104 self.table_stats.insert(name.to_string(), stats);
105 }
106
107 pub fn set_avg_fanout(&mut self, fanout: f64) {
109 self.avg_fanout = fanout;
110 }
111
112 #[must_use]
114 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
115 match op {
116 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
117 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
118 LogicalOperator::Project(project) => self.estimate_project(project),
119 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
120 LogicalOperator::Join(join) => self.estimate_join(join),
121 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
122 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
123 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
124 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
125 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
126 LogicalOperator::Return(ret) => self.estimate(&ret.input),
127 LogicalOperator::Empty => 0.0,
128 _ => self.default_row_count as f64,
129 }
130 }
131
132 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
134 if let Some(label) = &scan.label {
135 if let Some(stats) = self.table_stats.get(label) {
136 return stats.row_count as f64;
137 }
138 }
139 self.default_row_count as f64
141 }
142
143 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
145 let input_cardinality = self.estimate(&filter.input);
146 let selectivity = self.estimate_selectivity(&filter.predicate);
147 (input_cardinality * selectivity).max(1.0)
148 }
149
150 fn estimate_project(&self, project: &ProjectOp) -> f64 {
152 self.estimate(&project.input)
153 }
154
155 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
157 let input_cardinality = self.estimate(&expand.input);
158
159 let fanout = if expand.edge_type.is_some() {
161 self.avg_fanout * 0.5
163 } else {
164 self.avg_fanout
165 };
166
167 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
169 let min = expand.min_hops as f64;
170 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
171 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
173 } else {
174 fanout
175 };
176
177 (input_cardinality * path_multiplier).max(1.0)
178 }
179
180 fn estimate_join(&self, join: &JoinOp) -> f64 {
182 let left_card = self.estimate(&join.left);
183 let right_card = self.estimate(&join.right);
184
185 match join.join_type {
186 JoinType::Cross => left_card * right_card,
187 JoinType::Inner => {
188 let selectivity = if join.conditions.is_empty() {
190 1.0 } else {
192 0.1_f64.powi(join.conditions.len() as i32)
194 };
195 (left_card * right_card * selectivity).max(1.0)
196 }
197 JoinType::Left => {
198 let inner_card = self.estimate_join(&JoinOp {
200 left: join.left.clone(),
201 right: join.right.clone(),
202 join_type: JoinType::Inner,
203 conditions: join.conditions.clone(),
204 });
205 inner_card.max(left_card)
206 }
207 JoinType::Right => {
208 let inner_card = self.estimate_join(&JoinOp {
210 left: join.left.clone(),
211 right: join.right.clone(),
212 join_type: JoinType::Inner,
213 conditions: join.conditions.clone(),
214 });
215 inner_card.max(right_card)
216 }
217 JoinType::Full => {
218 let inner_card = self.estimate_join(&JoinOp {
220 left: join.left.clone(),
221 right: join.right.clone(),
222 join_type: JoinType::Inner,
223 conditions: join.conditions.clone(),
224 });
225 inner_card.max(left_card.max(right_card))
226 }
227 JoinType::Semi => {
228 (left_card * self.default_selectivity).max(1.0)
230 }
231 JoinType::Anti => {
232 (left_card * (1.0 - self.default_selectivity)).max(1.0)
234 }
235 }
236 }
237
238 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
240 let input_cardinality = self.estimate(&agg.input);
241
242 if agg.group_by.is_empty() {
243 1.0
245 } else {
246 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
249 (input_cardinality / group_reduction).max(1.0)
250 }
251 }
252
253 fn estimate_sort(&self, sort: &SortOp) -> f64 {
255 self.estimate(&sort.input)
256 }
257
258 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
260 let input_cardinality = self.estimate(&distinct.input);
261 (input_cardinality * 0.5).max(1.0)
263 }
264
265 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
267 let input_cardinality = self.estimate(&limit.input);
268 (limit.count as f64).min(input_cardinality)
269 }
270
271 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
273 let input_cardinality = self.estimate(&skip.input);
274 (input_cardinality - skip.count as f64).max(0.0)
275 }
276
277 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
279 match expr {
280 LogicalExpression::Binary { left, op, right } => {
281 self.estimate_binary_selectivity(left, *op, right)
282 }
283 LogicalExpression::Unary { op, operand } => {
284 self.estimate_unary_selectivity(*op, operand)
285 }
286 LogicalExpression::Literal(value) => {
287 if let grafeo_common::types::Value::Bool(b) = value {
289 if *b { 1.0 } else { 0.0 }
290 } else {
291 self.default_selectivity
292 }
293 }
294 _ => self.default_selectivity,
295 }
296 }
297
298 fn estimate_binary_selectivity(
300 &self,
301 _left: &LogicalExpression,
302 op: BinaryOp,
303 _right: &LogicalExpression,
304 ) -> f64 {
305 match op {
306 BinaryOp::Eq => 0.01,
308 BinaryOp::Ne => 0.99,
310 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => 0.33,
312 BinaryOp::And => {
314 self.default_selectivity * self.default_selectivity
316 }
317 BinaryOp::Or => {
318 1.0 - (1.0 - self.default_selectivity) * (1.0 - self.default_selectivity)
320 }
321 BinaryOp::StartsWith => 0.1,
323 BinaryOp::EndsWith => 0.1,
324 BinaryOp::Contains => 0.1,
325 BinaryOp::Like => 0.1,
326 BinaryOp::In => 0.1,
328 _ => self.default_selectivity,
330 }
331 }
332
333 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
335 match op {
336 UnaryOp::Not => 1.0 - self.default_selectivity,
337 UnaryOp::IsNull => 0.05, UnaryOp::IsNotNull => 0.95,
339 UnaryOp::Neg => 1.0, }
341 }
342
343 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
345 self.table_stats.get(label)?.columns.get(column)
346 }
347
348 #[allow(dead_code)]
350 fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
351 if let Some(stats) = self.get_column_stats(label, column) {
352 if stats.distinct_count > 0 {
353 return 1.0 / stats.distinct_count as f64;
354 }
355 }
356 0.01 }
358
359 #[allow(dead_code)]
361 fn estimate_range_with_stats(
362 &self,
363 label: &str,
364 column: &str,
365 lower: Option<f64>,
366 upper: Option<f64>,
367 ) -> f64 {
368 if let Some(stats) = self.get_column_stats(label, column) {
369 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
370 let range = max - min;
371 if range <= 0.0 {
372 return 1.0;
373 }
374
375 let effective_lower = lower.unwrap_or(min).max(min);
376 let effective_upper = upper.unwrap_or(max).min(max);
377
378 let overlap = (effective_upper - effective_lower).max(0.0);
379 return (overlap / range).min(1.0).max(0.0);
380 }
381 }
382 0.33 }
384}
385
386impl Default for CardinalityEstimator {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
        Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
    };
    use grafeo_common::types::Value;

    // ----------------------------------------------------------------------
    // Fixtures
    // ----------------------------------------------------------------------

    /// Absolute tolerance used when comparing floating-point estimates.
    const TOLERANCE: f64 = 0.001;

    /// Asserts that `actual` lies within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Builds an estimator holding `TableStats::new(rows)` for every `(label, rows)` pair.
    fn estimator_with(tables: &[(&str, u64)]) -> CardinalityEstimator {
        let mut estimator = CardinalityEstimator::new();
        for &(label, rows) in tables {
            estimator.add_table_stats(label, TableStats::new(rows));
        }
        estimator
    }

    /// An estimator whose `Person.age` column has 50 distinct values in `0.0..=100.0`.
    fn estimator_with_age_stats() -> CardinalityEstimator {
        let mut estimator = CardinalityEstimator::new();
        estimator.add_table_stats(
            "Person",
            TableStats::new(1000)
                .with_column("age", ColumnStats::new(50).with_range(0.0, 100.0)),
        );
        estimator
    }

    /// A leaf node scan binding `variable`, optionally restricted to `label`.
    fn node_scan(variable: &str, label: Option<&str>) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: label.map(str::to_string),
            input: None,
        })
    }

    /// A boxed scan of `(variable:label)`, ready to be used as an operator input.
    fn labeled_input(variable: &str, label: &str) -> Box<LogicalOperator> {
        Box::new(node_scan(variable, Some(label)))
    }

    /// The expression `var.name`.
    fn prop(var: &str, name: &str) -> LogicalExpression {
        LogicalExpression::Property {
            variable: var.to_string(),
            property: name.to_string(),
        }
    }

    /// A bare variable reference.
    fn var_ref(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// The predicate `n.age <op> 30`.
    fn age_vs_30(op: BinaryOp) -> LogicalExpression {
        LogicalExpression::Binary {
            left: Box::new(prop("n", "age")),
            op,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        }
    }

    /// The predicate `true <op> true`.
    fn true_vs_true(op: BinaryOp) -> LogicalExpression {
        LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Literal(Value::Bool(true))),
            op,
            right: Box::new(LogicalExpression::Literal(Value::Bool(true))),
        }
    }

    /// A filter applying `predicate` to a scan of `(n:Person)`.
    fn filter_people(predicate: LogicalExpression) -> LogicalOperator {
        LogicalOperator::Filter(FilterOp {
            predicate,
            input: labeled_input("n", "Person"),
        })
    }

    /// An aggregate over `(n:Person)` grouped by `group_by`, with no aggregate functions.
    fn aggregate_people(group_by: Vec<LogicalExpression>) -> LogicalOperator {
        LogicalOperator::Aggregate(AggregateOp {
            group_by,
            aggregates: vec![],
            input: labeled_input("n", "Person"),
        })
    }

    /// A single-hop outgoing expansion from `(a:Person)` to `b`, optionally
    /// restricted to `edge_type`.
    fn expand_people_one_hop(edge_type: Option<&str>) -> LogicalOperator {
        LogicalOperator::Expand(ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: edge_type.map(str::to_string),
            min_hops: 1,
            max_hops: Some(1),
            input: labeled_input("a", "Person"),
        })
    }

    /// The join condition `p.company_id = c.id`.
    fn company_id_condition() -> JoinCondition {
        JoinCondition {
            left: prop("p", "company_id"),
            right: prop("c", "id"),
        }
    }

    /// A join of `(p:Person)` (left) with `(c:Company)` (right).
    fn join_people_companies(
        join_type: JoinType,
        conditions: Vec<JoinCondition>,
    ) -> LogicalOperator {
        LogicalOperator::Join(JoinOp {
            left: labeled_input("p", "Person"),
            right: labeled_input("c", "Company"),
            join_type,
            conditions,
        })
    }

    // ----------------------------------------------------------------------
    // Scans
    // ----------------------------------------------------------------------

    #[test]
    fn test_node_scan_with_stats() {
        let estimator = estimator_with(&[("Person", 5000)]);
        assert_close(estimator.estimate(&node_scan("n", Some("Person"))), 5000.0);
    }

    #[test]
    fn test_node_scan_without_stats() {
        let estimator = CardinalityEstimator::new();
        assert_close(estimator.estimate(&node_scan("n", Some("Unknown"))), 1000.0);
    }

    #[test]
    fn test_node_scan_no_label() {
        let estimator = CardinalityEstimator::new();
        assert_close(estimator.estimate(&node_scan("n", None)), 1000.0);
    }

    #[test]
    fn test_estimator_default() {
        let estimator = CardinalityEstimator::default();
        assert_close(estimator.estimate(&node_scan("n", None)), 1000.0);
    }

    #[test]
    fn test_empty_cardinality() {
        let estimator = CardinalityEstimator::new();
        assert_close(estimator.estimate(&LogicalOperator::Empty), 0.0);
    }

    // ----------------------------------------------------------------------
    // Filters and selectivity
    // ----------------------------------------------------------------------

    #[test]
    fn test_filter_reduces_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let cardinality = estimator.estimate(&filter_people(age_vs_30(BinaryOp::Eq)));
        assert!(cardinality < 1000.0);
        assert!(cardinality >= 1.0);
    }

    #[test]
    fn test_filter_inequality_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let cardinality = estimator.estimate(&filter_people(age_vs_30(BinaryOp::Ne)));
        assert!(cardinality > 900.0);
    }

    #[test]
    fn test_filter_range_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let cardinality = estimator.estimate(&filter_people(age_vs_30(BinaryOp::Gt)));
        assert!(cardinality < 500.0);
        assert!(cardinality > 100.0);
    }

    #[test]
    fn test_filter_and_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let cardinality = estimator.estimate(&filter_people(true_vs_true(BinaryOp::And)));
        assert!(cardinality < 1000.0);
    }

    #[test]
    fn test_filter_or_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let cardinality = estimator.estimate(&filter_people(true_vs_true(BinaryOp::Or)));
        assert!(cardinality < 1000.0);
    }

    #[test]
    fn test_filter_literal_true() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let filter = filter_people(LogicalExpression::Literal(Value::Bool(true)));
        assert_close(estimator.estimate(&filter), 1000.0);
    }

    #[test]
    fn test_filter_literal_false() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let filter = filter_people(LogicalExpression::Literal(Value::Bool(false)));
        // Filters never estimate fewer than one row.
        assert_close(estimator.estimate(&filter), 1.0);
    }

    #[test]
    fn test_unary_not_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let filter = filter_people(LogicalExpression::Unary {
            op: UnaryOp::Not,
            operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
        });
        assert!(estimator.estimate(&filter) < 1000.0);
    }

    #[test]
    fn test_unary_is_null_selectivity() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let filter = filter_people(LogicalExpression::Unary {
            op: UnaryOp::IsNull,
            operand: Box::new(var_ref("x")),
        });
        assert!(estimator.estimate(&filter) < 100.0);
    }

    #[test]
    fn test_equality_selectivity_uses_distinct_count() {
        let estimator = estimator_with_age_stats();
        assert_close(estimator.estimate_equality_with_stats("Person", "age"), 0.02);
        // Unknown column or table: heuristic fallback.
        assert_close(estimator.estimate_equality_with_stats("Person", "name"), 0.01);
        assert_close(estimator.estimate_equality_with_stats("Company", "age"), 0.01);
    }

    #[test]
    fn test_range_selectivity_uses_min_max() {
        let estimator = estimator_with_age_stats();
        assert_close(
            estimator.estimate_range_with_stats("Person", "age", Some(25.0), Some(75.0)),
            0.5,
        );
        assert_close(
            estimator.estimate_range_with_stats("Person", "age", None, Some(50.0)),
            0.5,
        );
        // Entirely outside the column's range.
        assert_close(
            estimator.estimate_range_with_stats("Person", "age", Some(200.0), None),
            0.0,
        );
        // No statistics for this column: heuristic fallback.
        assert_close(
            estimator.estimate_range_with_stats("Person", "name", Some(1.0), Some(2.0)),
            0.33,
        );
    }

    // ----------------------------------------------------------------------
    // Expansions
    // ----------------------------------------------------------------------

    #[test]
    fn test_expand_cardinality() {
        let estimator = estimator_with(&[("Person", 100)]);
        assert!(estimator.estimate(&expand_people_one_hop(None)) > 100.0);
    }

    #[test]
    fn test_expand_with_edge_type_filter() {
        let estimator = estimator_with(&[("Person", 100)]);
        assert!(estimator.estimate(&expand_people_one_hop(Some("KNOWS"))) > 100.0);
    }

    #[test]
    fn test_expand_variable_length() {
        let estimator = estimator_with(&[("Person", 100)]);
        let expand = LogicalOperator::Expand(ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(3),
            input: labeled_input("a", "Person"),
        });
        assert!(estimator.estimate(&expand) > 500.0);
    }

    #[test]
    fn test_set_avg_fanout() {
        let mut estimator = estimator_with(&[("Person", 100)]);
        estimator.set_avg_fanout(5.0);
        assert_close(estimator.estimate(&expand_people_one_hop(None)), 500.0);
    }

    // ----------------------------------------------------------------------
    // Joins
    // ----------------------------------------------------------------------

    #[test]
    fn test_join_cardinality() {
        let estimator = estimator_with(&[("Person", 1000), ("Company", 100)]);
        let join = join_people_companies(JoinType::Inner, vec![company_id_condition()]);
        assert!(estimator.estimate(&join) < 1000.0 * 100.0);
    }

    #[test]
    fn test_join_cross_product() {
        let estimator = estimator_with(&[("Person", 100), ("Company", 50)]);
        let join = join_people_companies(JoinType::Cross, vec![]);
        assert_close(estimator.estimate(&join), 5000.0);
    }

    #[test]
    fn test_join_left_outer() {
        let estimator = estimator_with(&[("Person", 1000), ("Company", 10)]);
        let join = join_people_companies(
            JoinType::Left,
            vec![JoinCondition {
                left: var_ref("p"),
                right: var_ref("c"),
            }],
        );
        assert!(estimator.estimate(&join) >= 1000.0);
    }

    #[test]
    fn test_outer_joins_preserve_their_sides() {
        let estimator = estimator_with(&[("Person", 1000), ("Company", 10)]);
        let estimate_as = |join_type| {
            estimator.estimate(&join_people_companies(join_type, vec![company_id_condition()]))
        };

        let inner = estimate_as(JoinType::Inner);
        let left = estimate_as(JoinType::Left);
        let right = estimate_as(JoinType::Right);
        let full = estimate_as(JoinType::Full);

        assert!(left >= inner && left >= 1000.0);
        assert!(right >= inner && right >= 10.0);
        assert!(full >= left && full >= right);
    }

    #[test]
    fn test_join_semi() {
        let estimator = estimator_with(&[("Person", 1000), ("Company", 100)]);
        let join = join_people_companies(JoinType::Semi, vec![]);
        assert!(estimator.estimate(&join) <= 1000.0);
    }

    #[test]
    fn test_join_anti() {
        let estimator = estimator_with(&[("Person", 1000), ("Company", 100)]);
        let cardinality = estimator.estimate(&join_people_companies(JoinType::Anti, vec![]));
        assert!(cardinality <= 1000.0);
        assert!(cardinality >= 1.0);
    }

    // ----------------------------------------------------------------------
    // Aggregation, ordering and paging
    // ----------------------------------------------------------------------

    #[test]
    fn test_aggregate_reduces_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);

        // A global aggregate always produces a single row.
        assert_close(estimator.estimate(&aggregate_people(vec![])), 1.0);

        let grouped = aggregate_people(vec![prop("n", "city")]);
        assert!(estimator.estimate(&grouped) < 1000.0);
    }

    #[test]
    fn test_multiple_group_by_keys_reduce_cardinality() {
        let estimator = estimator_with(&[("Person", 10000)]);

        let single_card = estimator.estimate(&aggregate_people(vec![prop("n", "city")]));
        let multi_card = estimator.estimate(&aggregate_people(vec![
            prop("n", "city"),
            prop("n", "country"),
        ]));

        assert!(single_card < 10000.0);
        assert!(multi_card < 10000.0);
        assert!(single_card >= 1.0);
        assert!(multi_card >= 1.0);
    }

    #[test]
    fn test_project_preserves_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let project = LogicalOperator::Project(ProjectOp {
            projections: vec![Projection {
                expression: var_ref("n"),
                alias: None,
            }],
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&project), 1000.0);
    }

    #[test]
    fn test_sort_preserves_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let sort = LogicalOperator::Sort(SortOp {
            keys: vec![SortKey {
                expression: var_ref("n"),
                order: SortOrder::Ascending,
            }],
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&sort), 1000.0);
    }

    #[test]
    fn test_distinct_reduces_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let distinct = LogicalOperator::Distinct(DistinctOp {
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&distinct), 500.0);
    }

    #[test]
    fn test_limit_caps_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let limit = LogicalOperator::Limit(LimitOp {
            count: 10,
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&limit), 10.0);
    }

    #[test]
    fn test_limit_larger_than_input() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let limit = LogicalOperator::Limit(LimitOp {
            count: 5000,
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&limit), 1000.0);
    }

    #[test]
    fn test_skip_reduces_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let skip = LogicalOperator::Skip(SkipOp {
            count: 100,
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&skip), 900.0);
    }

    #[test]
    fn test_skip_past_end_is_zero() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let skip = LogicalOperator::Skip(SkipOp {
            count: 5000,
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&skip), 0.0);
    }

    #[test]
    fn test_return_preserves_cardinality() {
        let estimator = estimator_with(&[("Person", 1000)]);
        let ret = LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: var_ref("n"),
                alias: None,
            }],
            distinct: false,
            input: labeled_input("n", "Person"),
        });
        assert_close(estimator.estimate(&ret), 1000.0);
    }

    // ----------------------------------------------------------------------
    // Statistics builders
    // ----------------------------------------------------------------------

    #[test]
    fn test_table_stats_with_column() {
        let stats = TableStats::new(1000).with_column(
            "age",
            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
        );

        assert_eq!(stats.row_count, 1000);
        let col = stats.columns.get("age").unwrap();
        assert_eq!(col.distinct_count, 50);
        assert_eq!(col.null_count, 10);
        assert_close(col.min_value.unwrap(), 0.0);
        assert_close(col.max_value.unwrap(), 100.0);
    }
}