// grafeo_engine/query/optimizer/cardinality.rs

//! Cardinality estimation for query optimization.
//!
//! Estimates the number of rows produced by each operator in a query plan.
4
use crate::query::plan::{
    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
    LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
};
use std::collections::HashMap;
10
11/// Statistics for a table/label.
12#[derive(Debug, Clone)]
13pub struct TableStats {
14    /// Total number of rows.
15    pub row_count: u64,
16    /// Column statistics.
17    pub columns: HashMap<String, ColumnStats>,
18}
19
20impl TableStats {
21    /// Creates new table statistics.
22    #[must_use]
23    pub fn new(row_count: u64) -> Self {
24        Self {
25            row_count,
26            columns: HashMap::new(),
27        }
28    }
29
30    /// Adds column statistics.
31    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
32        self.columns.insert(name.to_string(), stats);
33        self
34    }
35}
36
/// Statistics for a column.
///
/// `Default` yields a column with no distinct values, no nulls and no known
/// value range.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ColumnStats {
    /// Number of distinct values.
    pub distinct_count: u64,
    /// Number of null values.
    pub null_count: u64,
    /// Minimum value (if orderable).
    pub min_value: Option<f64>,
    /// Maximum value (if orderable).
    pub max_value: Option<f64>,
}
49
50impl ColumnStats {
51    /// Creates new column statistics.
52    #[must_use]
53    pub fn new(distinct_count: u64) -> Self {
54        Self {
55            distinct_count,
56            null_count: 0,
57            min_value: None,
58            max_value: None,
59        }
60    }
61
62    /// Sets the null count.
63    #[must_use]
64    pub fn with_nulls(mut self, null_count: u64) -> Self {
65        self.null_count = null_count;
66        self
67    }
68
69    /// Sets the min/max range.
70    #[must_use]
71    pub fn with_range(mut self, min: f64, max: f64) -> Self {
72        self.min_value = Some(min);
73        self.max_value = Some(max);
74        self
75    }
76}
77
78/// Cardinality estimator.
79pub struct CardinalityEstimator {
80    /// Statistics for each label/table.
81    table_stats: HashMap<String, TableStats>,
82    /// Default row count for unknown tables.
83    default_row_count: u64,
84    /// Default selectivity for unknown predicates.
85    default_selectivity: f64,
86    /// Average edge fanout (outgoing edges per node).
87    avg_fanout: f64,
88}
89
90impl CardinalityEstimator {
91    /// Creates a new cardinality estimator with default settings.
92    #[must_use]
93    pub fn new() -> Self {
94        Self {
95            table_stats: HashMap::new(),
96            default_row_count: 1000,
97            default_selectivity: 0.1,
98            avg_fanout: 10.0,
99        }
100    }
101
102    /// Adds statistics for a table/label.
103    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
104        self.table_stats.insert(name.to_string(), stats);
105    }
106
107    /// Sets the average edge fanout.
108    pub fn set_avg_fanout(&mut self, fanout: f64) {
109        self.avg_fanout = fanout;
110    }
111
112    /// Estimates the cardinality of a logical operator.
113    #[must_use]
114    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
115        match op {
116            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
117            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
118            LogicalOperator::Project(project) => self.estimate_project(project),
119            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
120            LogicalOperator::Join(join) => self.estimate_join(join),
121            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
122            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
123            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
124            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
125            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
126            LogicalOperator::Return(ret) => self.estimate(&ret.input),
127            LogicalOperator::Empty => 0.0,
128            _ => self.default_row_count as f64,
129        }
130    }
131
132    /// Estimates node scan cardinality.
133    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
134        if let Some(label) = &scan.label {
135            if let Some(stats) = self.table_stats.get(label) {
136                return stats.row_count as f64;
137            }
138        }
139        // No label filter - scan all nodes
140        self.default_row_count as f64
141    }
142
143    /// Estimates filter cardinality.
144    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
145        let input_cardinality = self.estimate(&filter.input);
146        let selectivity = self.estimate_selectivity(&filter.predicate);
147        (input_cardinality * selectivity).max(1.0)
148    }
149
150    /// Estimates projection cardinality (same as input).
151    fn estimate_project(&self, project: &ProjectOp) -> f64 {
152        self.estimate(&project.input)
153    }
154
155    /// Estimates expand cardinality.
156    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
157        let input_cardinality = self.estimate(&expand.input);
158
159        // Apply fanout based on edge type
160        let fanout = if expand.edge_type.is_some() {
161            // Specific edge type typically has lower fanout
162            self.avg_fanout * 0.5
163        } else {
164            self.avg_fanout
165        };
166
167        // Handle variable-length paths
168        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
169            let min = expand.min_hops as f64;
170            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
171            // Geometric series approximation
172            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
173        } else {
174            fanout
175        };
176
177        (input_cardinality * path_multiplier).max(1.0)
178    }
179
180    /// Estimates join cardinality.
181    fn estimate_join(&self, join: &JoinOp) -> f64 {
182        let left_card = self.estimate(&join.left);
183        let right_card = self.estimate(&join.right);
184
185        match join.join_type {
186            JoinType::Cross => left_card * right_card,
187            JoinType::Inner => {
188                // Assume join selectivity based on conditions
189                let selectivity = if join.conditions.is_empty() {
190                    1.0 // Cross join
191                } else {
192                    // Estimate based on number of conditions
193                    0.1_f64.powi(join.conditions.len() as i32)
194                };
195                (left_card * right_card * selectivity).max(1.0)
196            }
197            JoinType::Left => {
198                // Left join returns at least all left rows
199                let inner_card = self.estimate_join(&JoinOp {
200                    left: join.left.clone(),
201                    right: join.right.clone(),
202                    join_type: JoinType::Inner,
203                    conditions: join.conditions.clone(),
204                });
205                inner_card.max(left_card)
206            }
207            JoinType::Right => {
208                // Right join returns at least all right rows
209                let inner_card = self.estimate_join(&JoinOp {
210                    left: join.left.clone(),
211                    right: join.right.clone(),
212                    join_type: JoinType::Inner,
213                    conditions: join.conditions.clone(),
214                });
215                inner_card.max(right_card)
216            }
217            JoinType::Full => {
218                // Full join returns at least max(left, right)
219                let inner_card = self.estimate_join(&JoinOp {
220                    left: join.left.clone(),
221                    right: join.right.clone(),
222                    join_type: JoinType::Inner,
223                    conditions: join.conditions.clone(),
224                });
225                inner_card.max(left_card.max(right_card))
226            }
227            JoinType::Semi => {
228                // Semi join returns at most left cardinality
229                (left_card * self.default_selectivity).max(1.0)
230            }
231            JoinType::Anti => {
232                // Anti join returns at most left cardinality
233                (left_card * (1.0 - self.default_selectivity)).max(1.0)
234            }
235        }
236    }
237
238    /// Estimates aggregation cardinality.
239    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
240        let input_cardinality = self.estimate(&agg.input);
241
242        if agg.group_by.is_empty() {
243            // Global aggregation - single row
244            1.0
245        } else {
246            // Group by - estimate distinct groups
247            // Assume each group key reduces cardinality by 10
248            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
249            (input_cardinality / group_reduction).max(1.0)
250        }
251    }
252
253    /// Estimates sort cardinality (same as input).
254    fn estimate_sort(&self, sort: &SortOp) -> f64 {
255        self.estimate(&sort.input)
256    }
257
258    /// Estimates distinct cardinality.
259    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
260        let input_cardinality = self.estimate(&distinct.input);
261        // Assume 50% distinct by default
262        (input_cardinality * 0.5).max(1.0)
263    }
264
265    /// Estimates limit cardinality.
266    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
267        let input_cardinality = self.estimate(&limit.input);
268        (limit.count as f64).min(input_cardinality)
269    }
270
271    /// Estimates skip cardinality.
272    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
273        let input_cardinality = self.estimate(&skip.input);
274        (input_cardinality - skip.count as f64).max(0.0)
275    }
276
277    /// Estimates the selectivity of a predicate (0.0 to 1.0).
278    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
279        match expr {
280            LogicalExpression::Binary { left, op, right } => {
281                self.estimate_binary_selectivity(left, *op, right)
282            }
283            LogicalExpression::Unary { op, operand } => {
284                self.estimate_unary_selectivity(*op, operand)
285            }
286            LogicalExpression::Literal(value) => {
287                // Boolean literal
288                if let grafeo_common::types::Value::Bool(b) = value {
289                    if *b { 1.0 } else { 0.0 }
290                } else {
291                    self.default_selectivity
292                }
293            }
294            _ => self.default_selectivity,
295        }
296    }
297
298    /// Estimates binary expression selectivity.
299    fn estimate_binary_selectivity(
300        &self,
301        _left: &LogicalExpression,
302        op: BinaryOp,
303        _right: &LogicalExpression,
304    ) -> f64 {
305        match op {
306            // Equality is typically very selective
307            BinaryOp::Eq => 0.01,
308            // Inequality is very unselective
309            BinaryOp::Ne => 0.99,
310            // Range predicates
311            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => 0.33,
312            // Logical operators
313            BinaryOp::And => {
314                // AND reduces selectivity (multiply)
315                self.default_selectivity * self.default_selectivity
316            }
317            BinaryOp::Or => {
318                // OR increases selectivity (1 - (1-s1)(1-s2))
319                1.0 - (1.0 - self.default_selectivity) * (1.0 - self.default_selectivity)
320            }
321            // String operations
322            BinaryOp::StartsWith => 0.1,
323            BinaryOp::EndsWith => 0.1,
324            BinaryOp::Contains => 0.1,
325            BinaryOp::Like => 0.1,
326            // Collection membership
327            BinaryOp::In => 0.1,
328            // Other operations
329            _ => self.default_selectivity,
330        }
331    }
332
333    /// Estimates unary expression selectivity.
334    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
335        match op {
336            UnaryOp::Not => 1.0 - self.default_selectivity,
337            UnaryOp::IsNull => 0.05, // Assume 5% nulls
338            UnaryOp::IsNotNull => 0.95,
339            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
340        }
341    }
342
343    /// Gets statistics for a column.
344    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
345        self.table_stats.get(label)?.columns.get(column)
346    }
347
348    /// Estimates equality selectivity using column statistics.
349    #[allow(dead_code)]
350    fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
351        if let Some(stats) = self.get_column_stats(label, column) {
352            if stats.distinct_count > 0 {
353                return 1.0 / stats.distinct_count as f64;
354            }
355        }
356        0.01 // Default for equality
357    }
358
359    /// Estimates range selectivity using column statistics.
360    #[allow(dead_code)]
361    fn estimate_range_with_stats(
362        &self,
363        label: &str,
364        column: &str,
365        lower: Option<f64>,
366        upper: Option<f64>,
367    ) -> f64 {
368        if let Some(stats) = self.get_column_stats(label, column) {
369            if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
370                let range = max - min;
371                if range <= 0.0 {
372                    return 1.0;
373                }
374
375                let effective_lower = lower.unwrap_or(min).max(min);
376                let effective_upper = upper.unwrap_or(max).min(max);
377
378                let overlap = (effective_upper - effective_lower).max(0.0);
379                return (overlap / range).min(1.0).max(0.0);
380            }
381        }
382        0.33 // Default for range
383    }
384}
385
386impl Default for CardinalityEstimator {
387    fn default() -> Self {
388        Self::new()
389    }
390}
391
392#[cfg(test)]
393mod tests {
394    use super::*;
395    use crate::query::plan::{
396        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
397        Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
398    };
399    use grafeo_common::types::Value;
400
401    #[test]
402    fn test_node_scan_with_stats() {
403        let mut estimator = CardinalityEstimator::new();
404        estimator.add_table_stats("Person", TableStats::new(5000));
405
406        let scan = LogicalOperator::NodeScan(NodeScanOp {
407            variable: "n".to_string(),
408            label: Some("Person".to_string()),
409            input: None,
410        });
411
412        let cardinality = estimator.estimate(&scan);
413        assert!((cardinality - 5000.0).abs() < 0.001);
414    }
415
416    #[test]
417    fn test_filter_reduces_cardinality() {
418        let mut estimator = CardinalityEstimator::new();
419        estimator.add_table_stats("Person", TableStats::new(1000));
420
421        let filter = LogicalOperator::Filter(FilterOp {
422            predicate: LogicalExpression::Binary {
423                left: Box::new(LogicalExpression::Property {
424                    variable: "n".to_string(),
425                    property: "age".to_string(),
426                }),
427                op: BinaryOp::Eq,
428                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
429            },
430            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
431                variable: "n".to_string(),
432                label: Some("Person".to_string()),
433                input: None,
434            })),
435        });
436
437        let cardinality = estimator.estimate(&filter);
438        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
439        assert!(cardinality < 1000.0);
440        assert!(cardinality >= 1.0);
441    }
442
443    #[test]
444    fn test_join_cardinality() {
445        let mut estimator = CardinalityEstimator::new();
446        estimator.add_table_stats("Person", TableStats::new(1000));
447        estimator.add_table_stats("Company", TableStats::new(100));
448
449        let join = LogicalOperator::Join(JoinOp {
450            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
451                variable: "p".to_string(),
452                label: Some("Person".to_string()),
453                input: None,
454            })),
455            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
456                variable: "c".to_string(),
457                label: Some("Company".to_string()),
458                input: None,
459            })),
460            join_type: JoinType::Inner,
461            conditions: vec![JoinCondition {
462                left: LogicalExpression::Property {
463                    variable: "p".to_string(),
464                    property: "company_id".to_string(),
465                },
466                right: LogicalExpression::Property {
467                    variable: "c".to_string(),
468                    property: "id".to_string(),
469                },
470            }],
471        });
472
473        let cardinality = estimator.estimate(&join);
474        // Should be less than cross product
475        assert!(cardinality < 1000.0 * 100.0);
476    }
477
478    #[test]
479    fn test_limit_caps_cardinality() {
480        let mut estimator = CardinalityEstimator::new();
481        estimator.add_table_stats("Person", TableStats::new(1000));
482
483        let limit = LogicalOperator::Limit(LimitOp {
484            count: 10,
485            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
486                variable: "n".to_string(),
487                label: Some("Person".to_string()),
488                input: None,
489            })),
490        });
491
492        let cardinality = estimator.estimate(&limit);
493        assert!((cardinality - 10.0).abs() < 0.001);
494    }
495
496    #[test]
497    fn test_aggregate_reduces_cardinality() {
498        let mut estimator = CardinalityEstimator::new();
499        estimator.add_table_stats("Person", TableStats::new(1000));
500
501        // Global aggregation
502        let global_agg = LogicalOperator::Aggregate(AggregateOp {
503            group_by: vec![],
504            aggregates: vec![],
505            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
506                variable: "n".to_string(),
507                label: Some("Person".to_string()),
508                input: None,
509            })),
510        });
511
512        let cardinality = estimator.estimate(&global_agg);
513        assert!((cardinality - 1.0).abs() < 0.001);
514
515        // Group by aggregation
516        let group_agg = LogicalOperator::Aggregate(AggregateOp {
517            group_by: vec![LogicalExpression::Property {
518                variable: "n".to_string(),
519                property: "city".to_string(),
520            }],
521            aggregates: vec![],
522            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
523                variable: "n".to_string(),
524                label: Some("Person".to_string()),
525                input: None,
526            })),
527        });
528
529        let cardinality = estimator.estimate(&group_agg);
530        // Should be less than input
531        assert!(cardinality < 1000.0);
532    }
533
534    #[test]
535    fn test_node_scan_without_stats() {
536        let estimator = CardinalityEstimator::new();
537
538        let scan = LogicalOperator::NodeScan(NodeScanOp {
539            variable: "n".to_string(),
540            label: Some("Unknown".to_string()),
541            input: None,
542        });
543
544        let cardinality = estimator.estimate(&scan);
545        // Should return default (1000)
546        assert!((cardinality - 1000.0).abs() < 0.001);
547    }
548
549    #[test]
550    fn test_node_scan_no_label() {
551        let estimator = CardinalityEstimator::new();
552
553        let scan = LogicalOperator::NodeScan(NodeScanOp {
554            variable: "n".to_string(),
555            label: None,
556            input: None,
557        });
558
559        let cardinality = estimator.estimate(&scan);
560        // Should scan all nodes (default)
561        assert!((cardinality - 1000.0).abs() < 0.001);
562    }
563
564    #[test]
565    fn test_filter_inequality_selectivity() {
566        let mut estimator = CardinalityEstimator::new();
567        estimator.add_table_stats("Person", TableStats::new(1000));
568
569        let filter = LogicalOperator::Filter(FilterOp {
570            predicate: LogicalExpression::Binary {
571                left: Box::new(LogicalExpression::Property {
572                    variable: "n".to_string(),
573                    property: "age".to_string(),
574                }),
575                op: BinaryOp::Ne,
576                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
577            },
578            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
579                variable: "n".to_string(),
580                label: Some("Person".to_string()),
581                input: None,
582            })),
583        });
584
585        let cardinality = estimator.estimate(&filter);
586        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
587        assert!(cardinality > 900.0);
588    }
589
590    #[test]
591    fn test_filter_range_selectivity() {
592        let mut estimator = CardinalityEstimator::new();
593        estimator.add_table_stats("Person", TableStats::new(1000));
594
595        let filter = LogicalOperator::Filter(FilterOp {
596            predicate: LogicalExpression::Binary {
597                left: Box::new(LogicalExpression::Property {
598                    variable: "n".to_string(),
599                    property: "age".to_string(),
600                }),
601                op: BinaryOp::Gt,
602                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
603            },
604            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
605                variable: "n".to_string(),
606                label: Some("Person".to_string()),
607                input: None,
608            })),
609        });
610
611        let cardinality = estimator.estimate(&filter);
612        // Range selectivity is 0.33, so 1000 * 0.33 = 330
613        assert!(cardinality < 500.0);
614        assert!(cardinality > 100.0);
615    }
616
617    #[test]
618    fn test_filter_and_selectivity() {
619        let mut estimator = CardinalityEstimator::new();
620        estimator.add_table_stats("Person", TableStats::new(1000));
621
622        let filter = LogicalOperator::Filter(FilterOp {
623            predicate: LogicalExpression::Binary {
624                left: Box::new(LogicalExpression::Literal(Value::Bool(true))),
625                op: BinaryOp::And,
626                right: Box::new(LogicalExpression::Literal(Value::Bool(true))),
627            },
628            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
629                variable: "n".to_string(),
630                label: Some("Person".to_string()),
631                input: None,
632            })),
633        });
634
635        let cardinality = estimator.estimate(&filter);
636        // AND reduces selectivity (multiply)
637        assert!(cardinality < 1000.0);
638    }
639
640    #[test]
641    fn test_filter_or_selectivity() {
642        let mut estimator = CardinalityEstimator::new();
643        estimator.add_table_stats("Person", TableStats::new(1000));
644
645        let filter = LogicalOperator::Filter(FilterOp {
646            predicate: LogicalExpression::Binary {
647                left: Box::new(LogicalExpression::Literal(Value::Bool(true))),
648                op: BinaryOp::Or,
649                right: Box::new(LogicalExpression::Literal(Value::Bool(true))),
650            },
651            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
652                variable: "n".to_string(),
653                label: Some("Person".to_string()),
654                input: None,
655            })),
656        });
657
658        let cardinality = estimator.estimate(&filter);
659        // OR increases selectivity
660        assert!(cardinality < 1000.0);
661    }
662
663    #[test]
664    fn test_filter_literal_true() {
665        let mut estimator = CardinalityEstimator::new();
666        estimator.add_table_stats("Person", TableStats::new(1000));
667
668        let filter = LogicalOperator::Filter(FilterOp {
669            predicate: LogicalExpression::Literal(Value::Bool(true)),
670            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
671                variable: "n".to_string(),
672                label: Some("Person".to_string()),
673                input: None,
674            })),
675        });
676
677        let cardinality = estimator.estimate(&filter);
678        // Literal true has selectivity 1.0
679        assert!((cardinality - 1000.0).abs() < 0.001);
680    }
681
682    #[test]
683    fn test_filter_literal_false() {
684        let mut estimator = CardinalityEstimator::new();
685        estimator.add_table_stats("Person", TableStats::new(1000));
686
687        let filter = LogicalOperator::Filter(FilterOp {
688            predicate: LogicalExpression::Literal(Value::Bool(false)),
689            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
690                variable: "n".to_string(),
691                label: Some("Person".to_string()),
692                input: None,
693            })),
694        });
695
696        let cardinality = estimator.estimate(&filter);
697        // Literal false has selectivity 0.0, but min is 1.0
698        assert!((cardinality - 1.0).abs() < 0.001);
699    }
700
701    #[test]
702    fn test_unary_not_selectivity() {
703        let mut estimator = CardinalityEstimator::new();
704        estimator.add_table_stats("Person", TableStats::new(1000));
705
706        let filter = LogicalOperator::Filter(FilterOp {
707            predicate: LogicalExpression::Unary {
708                op: UnaryOp::Not,
709                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
710            },
711            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
712                variable: "n".to_string(),
713                label: Some("Person".to_string()),
714                input: None,
715            })),
716        });
717
718        let cardinality = estimator.estimate(&filter);
719        // NOT inverts selectivity
720        assert!(cardinality < 1000.0);
721    }
722
723    #[test]
724    fn test_unary_is_null_selectivity() {
725        let mut estimator = CardinalityEstimator::new();
726        estimator.add_table_stats("Person", TableStats::new(1000));
727
728        let filter = LogicalOperator::Filter(FilterOp {
729            predicate: LogicalExpression::Unary {
730                op: UnaryOp::IsNull,
731                operand: Box::new(LogicalExpression::Variable("x".to_string())),
732            },
733            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
734                variable: "n".to_string(),
735                label: Some("Person".to_string()),
736                input: None,
737            })),
738        });
739
740        let cardinality = estimator.estimate(&filter);
741        // IS NULL has selectivity 0.05
742        assert!(cardinality < 100.0);
743    }
744
745    #[test]
746    fn test_expand_cardinality() {
747        let mut estimator = CardinalityEstimator::new();
748        estimator.add_table_stats("Person", TableStats::new(100));
749
750        let expand = LogicalOperator::Expand(ExpandOp {
751            from_variable: "a".to_string(),
752            to_variable: "b".to_string(),
753            edge_variable: None,
754            direction: ExpandDirection::Outgoing,
755            edge_type: None,
756            min_hops: 1,
757            max_hops: Some(1),
758            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
759                variable: "a".to_string(),
760                label: Some("Person".to_string()),
761                input: None,
762            })),
763        });
764
765        let cardinality = estimator.estimate(&expand);
766        // Expand multiplies by fanout (10)
767        assert!(cardinality > 100.0);
768    }
769
770    #[test]
771    fn test_expand_with_edge_type_filter() {
772        let mut estimator = CardinalityEstimator::new();
773        estimator.add_table_stats("Person", TableStats::new(100));
774
775        let expand = LogicalOperator::Expand(ExpandOp {
776            from_variable: "a".to_string(),
777            to_variable: "b".to_string(),
778            edge_variable: None,
779            direction: ExpandDirection::Outgoing,
780            edge_type: Some("KNOWS".to_string()),
781            min_hops: 1,
782            max_hops: Some(1),
783            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
784                variable: "a".to_string(),
785                label: Some("Person".to_string()),
786                input: None,
787            })),
788        });
789
790        let cardinality = estimator.estimate(&expand);
791        // With edge type, fanout is reduced by half
792        assert!(cardinality > 100.0);
793    }
794
795    #[test]
796    fn test_expand_variable_length() {
797        let mut estimator = CardinalityEstimator::new();
798        estimator.add_table_stats("Person", TableStats::new(100));
799
800        let expand = LogicalOperator::Expand(ExpandOp {
801            from_variable: "a".to_string(),
802            to_variable: "b".to_string(),
803            edge_variable: None,
804            direction: ExpandDirection::Outgoing,
805            edge_type: None,
806            min_hops: 1,
807            max_hops: Some(3),
808            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
809                variable: "a".to_string(),
810                label: Some("Person".to_string()),
811                input: None,
812            })),
813        });
814
815        let cardinality = estimator.estimate(&expand);
816        // Variable length path has much higher cardinality
817        assert!(cardinality > 500.0);
818    }
819
/// A condition-free cross join is expected to multiply its inputs: 100 * 50 = 5000.
#[test]
fn test_join_cross_product() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.add_table_stats("Company", TableStats::new(50));

    // Builds a labelled node scan with no upstream input.
    let scan_of = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_owned(),
            label: Some(label.to_owned()),
            input: None,
        })
    };

    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_of("p", "Person")),
        right: Box::new(scan_of("c", "Company")),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let cardinality = est.estimate(&join);
    // Compared with a tolerance because the estimate is a float.
    assert!((cardinality - 5000.0).abs() < 0.001);
}
845
/// A left outer join keeps every left row, so the estimate must not drop below the
/// 1000-row `Person` side.
#[test]
fn test_join_left_outer() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(10));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });

    // Single join condition relating the two scanned variables.
    let condition = JoinCondition {
        left: LogicalExpression::Variable(String::from("p")),
        right: LogicalExpression::Variable(String::from("c")),
    };

    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Left,
        conditions: vec![condition],
    });

    let cardinality = est.estimate(&join);
    assert!(cardinality >= 1000.0);
}
874
/// A semi join can only filter left rows, so its estimate must not exceed the
/// 1000-row `Person` side.
#[test]
fn test_join_semi() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    // Builds a labelled node scan with no upstream input.
    let scan_of = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_owned(),
            label: Some(label.to_owned()),
            input: None,
        })
    };

    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_of("p", "Person")),
        right: Box::new(scan_of("c", "Company")),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    let cardinality = est.estimate(&join);
    assert!(cardinality <= 1000.0);
}
900
/// An anti join yields a subset of the left side: the estimate must stay within
/// `[1, 1000]` for a 1000-row `Person` input.
#[test]
fn test_join_anti() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("p"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("c"),
        label: Some(String::from("Company")),
        input: None,
    });

    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let cardinality = est.estimate(&join);
    // Upper bound: never more than the left input.
    assert!(cardinality <= 1000.0);
    // Lower bound: the estimate is expected to be at least one row.
    assert!(cardinality >= 1.0);
}
927
/// Projecting a variable must leave the row estimate of the 1000-row scan unchanged.
#[test]
fn test_project_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    // Project the scanned variable as-is, without an alias.
    let projection = Projection {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };

    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![projection],
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&project);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
948
/// Sorting reorders rows only, so the estimate must stay at the 1000-row input size.
#[test]
fn test_sort_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    // Single ascending key on the scanned variable.
    let key = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
    };

    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![key],
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&sort);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
969
/// `Distinct` over a 1000-row scan is expected to be estimated at 500 rows.
#[test]
fn test_distinct_reduces_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&distinct);
    // The expected value corresponds to half of the input rows being unique.
    assert!((cardinality - 500.0).abs() < 0.001);
}
987
/// Skipping 100 rows of a 1000-row scan must leave an estimate of 900 rows.
#[test]
fn test_skip_reduces_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 100,
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&skip);
    // 1000 input rows minus the 100 skipped.
    assert!((cardinality - 900.0).abs() < 0.001);
}
1005
/// A non-distinct `Return` must pass the 1000-row estimate of its input through unchanged.
#[test]
fn test_return_preserves_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    // Return the scanned variable itself, without an alias.
    let item = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };

    let ret = LogicalOperator::Return(ReturnOp {
        items: vec![item],
        distinct: false,
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&ret);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1027
/// The `Empty` operator must be estimated at (approximately) zero rows.
#[test]
fn test_empty_cardinality() {
    let est = CardinalityEstimator::new();
    let empty_plan = LogicalOperator::Empty;

    let cardinality = est.estimate(&empty_plan);
    // Tolerance-based comparison against zero because the estimate is a float.
    assert!((cardinality).abs() < 0.001);
}
1034
/// The `TableStats` / `ColumnStats` builders must store exactly the values they are given.
#[test]
fn test_table_stats_with_column() {
    // Column with 50 distinct values, 10 nulls and a [0, 100] value range.
    let age_stats = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let stats = TableStats::new(1000).with_column("age", age_stats);

    assert_eq!(stats.row_count, 1000);

    // The column must be retrievable under the name it was registered with.
    let col = stats.columns.get("age").unwrap();
    assert_eq!(col.distinct_count, 50);
    assert_eq!(col.null_count, 10);
    assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
    assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
}
1049
/// A `Default`-constructed estimator is expected to report 1000 rows for an unlabelled scan.
#[test]
fn test_estimator_default() {
    let est = CardinalityEstimator::default();

    // No label and no registered statistics: the estimator has nothing specific to look up.
    let unlabelled = NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    };
    let scan = LogicalOperator::NodeScan(unlabelled);

    let cardinality = est.estimate(&scan);
    assert!((cardinality - 1000.0).abs() < 0.001);
}
1061
/// With the average fanout set to 5, a single-hop expansion of 100 rows must be
/// estimated at 100 * 5 = 500 rows.
#[test]
fn test_set_avg_fanout() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.set_avg_fanout(5.0);

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });

    // Exactly one hop: min_hops == max_hops == 1.
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: None,
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
    });

    let cardinality = est.estimate(&expand);
    assert!((cardinality - 500.0).abs() < 0.001);
}
1087
/// Grouping by one or by two properties must both shrink the 10,000-row input while
/// never dropping below a single row.
///
/// The estimator is described as using a simplified model in which every additional
/// group-by key reduces the estimate further (dividing by 10^num_keys), a simplification
/// that suits correlated keys. This test therefore only asserts loose bounds.
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(10000));

    // Builds `Aggregate(group_by = n.<property>..., no aggregates)` over a `Person` scan.
    let aggregate_by = |properties: &[&str]| {
        LogicalOperator::Aggregate(AggregateOp {
            group_by: properties
                .iter()
                .map(|property| LogicalExpression::Property {
                    variable: String::from("n"),
                    property: (*property).to_owned(),
                })
                .collect(),
            aggregates: vec![],
            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                variable: String::from("n"),
                label: Some(String::from("Person")),
                input: None,
            })),
        })
    };

    let single_group = aggregate_by(&["city"]);
    let multi_group = aggregate_by(&["city", "country"]);

    let single_card = est.estimate(&single_group);
    let multi_card = est.estimate(&multi_group);

    // Grouping must reduce the estimate relative to the input...
    assert!(single_card < 10000.0);
    assert!(multi_card < 10000.0);
    // ...but an aggregate is still expected to produce at least one row.
    assert!(single_card >= 1.0);
    assert!(multi_card >= 1.0);
}
1138}