Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

//! Cardinality estimation for query optimization.
//!
//! Estimates the number of rows produced by each operator in a query plan.
//!
//! # Equi-Depth Histograms
//!
//! This module provides equi-depth histogram support for accurate selectivity
//! estimation of range predicates. Unlike equi-width histograms, which divide
//! the value range into equal-sized buckets, equi-depth histograms divide
//! the data into buckets containing approximately equal numbers of rows.
//!
//! Benefits:
//! - Better estimates for skewed data distributions
//! - More accurate range selectivity than assuming a uniform distribution
//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20    VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// A bucket in an equi-depth histogram.
///
/// Each bucket represents a range of values and the number of rows whose
/// value falls within that range. In an equi-depth histogram, all buckets
/// contain approximately the same number of rows.
///
/// A bucket normally covers the half-open interval `[lower_bound, upper_bound)`.
/// A degenerate bucket (`upper_bound <= lower_bound`) is treated as holding
/// only the single value `lower_bound`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Lower bound of the bucket (inclusive).
    pub lower_bound: f64,
    /// Upper bound of the bucket (exclusive). [`EquiDepthHistogram::build`]
    /// sets the last bucket's upper bound to the maximum value plus one so the
    /// maximum itself is still covered.
    pub upper_bound: f64,
    /// Number of rows in this bucket.
    pub frequency: u64,
    /// Number of distinct values in this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a new histogram bucket.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns the width of this bucket (`upper_bound - lower_bound`).
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` for a zero- (or negative-) width bucket, which stands for
    /// the single value `lower_bound`.
    fn is_degenerate(&self) -> bool {
        self.width() <= 0.0
    }

    /// Checks if a value falls within this bucket.
    ///
    /// Regular buckets use `[lower_bound, upper_bound)`; a degenerate bucket
    /// contains exactly `lower_bound`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        if self.is_degenerate() {
            value == self.lower_bound
        } else {
            value >= self.lower_bound && value < self.upper_bound
        }
    }

    /// Returns the fraction (0.0 to 1.0) of this bucket covered by the range
    /// `[lower, upper)`; `None` leaves that side unbounded.
    ///
    /// Assumes values are spread uniformly across the bucket. A degenerate
    /// bucket is either fully covered (1.0) or not at all (0.0).
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            return if effective_lower <= self.lower_bound && effective_upper >= self.upper_bound {
                1.0
            } else {
                0.0
            };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}

/// An equi-depth histogram for selectivity estimation.
///
/// Equi-depth histograms partition data into buckets where each bucket
/// contains approximately the same number of rows. This provides more
/// accurate selectivity estimates than assuming uniform distribution,
/// especially for skewed data.
///
/// # Example
///
/// ```ignore
/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
///
/// // Build a histogram from sorted values
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
/// let histogram = EquiDepthHistogram::build(&values, 4);
///
/// // Estimate selectivity for `value >= 25`
/// let selectivity = histogram.range_selectivity(Some(25.0), None);
/// ```
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// The histogram buckets, sorted by lower_bound.
    buckets: Vec<HistogramBucket>,
    /// Total number of rows represented by this histogram.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Creates a new histogram from pre-built buckets.
    ///
    /// The total row count is the sum of the bucket frequencies.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds an equi-depth histogram from a sorted slice of values.
    ///
    /// # Arguments
    /// * `values` - A sorted (ascending) slice of numeric values
    /// * `num_buckets` - The desired number of buckets
    ///
    /// # Returns
    /// A histogram whose buckets hold roughly `ceil(len / num_buckets)` rows
    /// each. A run of equal values is never split across buckets, so heavy
    /// values can make a bucket larger and the histogram have fewer buckets
    /// than requested. Empty input or `num_buckets == 0` yields an empty
    /// histogram.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self {
                buckets: Vec::new(),
                total_rows: 0,
            };
        }

        let num_buckets = num_buckets.min(values.len());
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        let mut start_idx = 0;
        while start_idx < values.len() {
            let mut end_idx = (start_idx + rows_per_bucket).min(values.len());
            // Keep a run of equal values inside one bucket. Splitting it would
            // leave rows equal to this bucket's exclusive upper bound (or create
            // a zero-width bucket), and equality/range estimates would miss them.
            while end_idx < values.len() && values[end_idx] == values[end_idx - 1] {
                end_idx += 1;
            }

            let lower_bound = values[start_idx];
            let upper_bound = if end_idx < values.len() {
                values[end_idx]
            } else {
                // For the last bucket, extend slightly beyond the max value
                values[end_idx - 1] + 1.0
            };

            let bucket_values = &values[start_idx..end_idx];
            buckets.push(HistogramBucket::new(
                lower_bound,
                upper_bound,
                bucket_values.len() as u64,
                count_distinct(bucket_values),
            ));

            start_idx = end_idx;
        }

        Self::new(buckets)
    }

    /// Returns the number of buckets in this histogram.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Returns the total number of rows represented.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Returns the histogram buckets.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimates selectivity for the range predicate `lower <= x < upper`.
    ///
    /// # Arguments
    /// * `lower` - Inclusive lower bound (None for unbounded)
    /// * `upper` - Exclusive upper bound (None for unbounded)
    ///
    /// # Returns
    /// Estimated fraction of rows matching the range (0.0 to 1.0), or 0.33 when
    /// the histogram is empty.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33; // Default fallback
        }

        let matching_rows: f64 = self
            .buckets
            .iter()
            .filter(|bucket| {
                // A regular bucket `[lo, hi)` lies entirely below the range when
                // `hi <= lower`; a degenerate bucket holds only `lo`, so it lies
                // below the range only when `lo < lower`.
                let below = lower.is_some_and(|l| {
                    if bucket.is_degenerate() {
                        bucket.upper_bound < l
                    } else {
                        bucket.upper_bound <= l
                    }
                });
                let above = upper.is_some_and(|u| bucket.lower_bound >= u);
                !below && !above
            })
            .map(|bucket| bucket.overlap_fraction(lower, upper) * bucket.frequency as f64)
            .sum();

        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
    }

    /// Estimates selectivity for the equality predicate `x == value`.
    ///
    /// Finds the first bucket containing `value` and assumes its rows are spread
    /// evenly over its distinct values. Returns 0.01 for an empty histogram and
    /// 0.001 when no bucket contains `value`.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01; // Default fallback
        }

        self.buckets
            .iter()
            .find(|bucket| bucket.contains(value) && bucket.distinct_count > 0)
            .map_or(0.001, |bucket| {
                (bucket.frequency as f64 / bucket.distinct_count as f64 / self.total_rows as f64)
                    .min(1.0)
            })
    }

    /// Gets the minimum value in the histogram (the first bucket's lower bound).
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Gets the maximum value in the histogram (the last bucket's upper bound,
    /// which for a histogram from [`EquiDepthHistogram::build`] is one past the
    /// largest value).
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts distinct values in a sorted slice.
///
/// A value is counted as new when it differs from the last counted value by
/// more than `f64::EPSILON` (an absolute tolerance).
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    let mut count = 1u64;
    let mut prev = sorted_values[0];

    for &val in &sorted_values[1..] {
        if (val - prev).abs() > f64::EPSILON {
            count += 1;
            prev = val;
        }
    }

    count
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
396/// Cardinality estimator.
397pub struct CardinalityEstimator {
398    /// Statistics for each label/table.
399    table_stats: HashMap<String, TableStats>,
400    /// Default row count for unknown tables.
401    default_row_count: u64,
402    /// Default selectivity for unknown predicates.
403    default_selectivity: f64,
404    /// Average edge fanout (outgoing edges per node).
405    avg_fanout: f64,
406}
407
408impl CardinalityEstimator {
409    /// Creates a new cardinality estimator with default settings.
410    #[must_use]
411    pub fn new() -> Self {
412        Self {
413            table_stats: HashMap::new(),
414            default_row_count: 1000,
415            default_selectivity: 0.1,
416            avg_fanout: 10.0,
417        }
418    }
419
420    /// Creates a cardinality estimator pre-populated from store statistics.
421    ///
422    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
423    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
424    /// missing statistics.
425    #[must_use]
426    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
427        let mut estimator = Self::new();
428
429        // Use total node count as default for unlabeled scans
430        if stats.total_nodes > 0 {
431            estimator.default_row_count = stats.total_nodes;
432        }
433
434        // Convert label statistics to optimizer table stats
435        for (label, label_stats) in &stats.labels {
436            let mut table_stats = TableStats::new(label_stats.node_count);
437
438            // Map property statistics (distinct count for selectivity estimation)
439            for (prop, col_stats) in &label_stats.properties {
440                let optimizer_col =
441                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
442                table_stats = table_stats.with_column(prop, optimizer_col);
443            }
444
445            estimator.add_table_stats(label, table_stats);
446        }
447
448        // Compute average fanout from edge type statistics
449        if !stats.edge_types.is_empty() {
450            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
451            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
452        } else if stats.total_nodes > 0 {
453            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
454        }
455
456        // Clamp fanout to a reasonable minimum
457        if estimator.avg_fanout < 1.0 {
458            estimator.avg_fanout = 1.0;
459        }
460
461        estimator
462    }
463
464    /// Adds statistics for a table/label.
465    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
466        self.table_stats.insert(name.to_string(), stats);
467    }
468
469    /// Sets the average edge fanout.
470    pub fn set_avg_fanout(&mut self, fanout: f64) {
471        self.avg_fanout = fanout;
472    }
473
474    /// Estimates the cardinality of a logical operator.
475    #[must_use]
476    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
477        match op {
478            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
479            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
480            LogicalOperator::Project(project) => self.estimate_project(project),
481            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
482            LogicalOperator::Join(join) => self.estimate_join(join),
483            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
484            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
485            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
486            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
487            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
488            LogicalOperator::Return(ret) => self.estimate(&ret.input),
489            LogicalOperator::Empty => 0.0,
490            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
491            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
492            _ => self.default_row_count as f64,
493        }
494    }
495
496    /// Estimates node scan cardinality.
497    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
498        if let Some(label) = &scan.label
499            && let Some(stats) = self.table_stats.get(label)
500        {
501            return stats.row_count as f64;
502        }
503        // No label filter - scan all nodes
504        self.default_row_count as f64
505    }
506
507    /// Estimates filter cardinality.
508    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
509        let input_cardinality = self.estimate(&filter.input);
510        let selectivity = self.estimate_selectivity(&filter.predicate);
511        (input_cardinality * selectivity).max(1.0)
512    }
513
514    /// Estimates projection cardinality (same as input).
515    fn estimate_project(&self, project: &ProjectOp) -> f64 {
516        self.estimate(&project.input)
517    }
518
519    /// Estimates expand cardinality.
520    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
521        let input_cardinality = self.estimate(&expand.input);
522
523        // Apply fanout based on edge type
524        let fanout = if expand.edge_type.is_some() {
525            // Specific edge type typically has lower fanout
526            self.avg_fanout * 0.5
527        } else {
528            self.avg_fanout
529        };
530
531        // Handle variable-length paths
532        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
533            let min = expand.min_hops as f64;
534            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
535            // Geometric series approximation
536            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
537        } else {
538            fanout
539        };
540
541        (input_cardinality * path_multiplier).max(1.0)
542    }
543
544    /// Estimates join cardinality.
545    fn estimate_join(&self, join: &JoinOp) -> f64 {
546        let left_card = self.estimate(&join.left);
547        let right_card = self.estimate(&join.right);
548
549        match join.join_type {
550            JoinType::Cross => left_card * right_card,
551            JoinType::Inner => {
552                // Assume join selectivity based on conditions
553                let selectivity = if join.conditions.is_empty() {
554                    1.0 // Cross join
555                } else {
556                    // Estimate based on number of conditions
557                    0.1_f64.powi(join.conditions.len() as i32)
558                };
559                (left_card * right_card * selectivity).max(1.0)
560            }
561            JoinType::Left => {
562                // Left join returns at least all left rows
563                let inner_card = self.estimate_join(&JoinOp {
564                    left: join.left.clone(),
565                    right: join.right.clone(),
566                    join_type: JoinType::Inner,
567                    conditions: join.conditions.clone(),
568                });
569                inner_card.max(left_card)
570            }
571            JoinType::Right => {
572                // Right join returns at least all right rows
573                let inner_card = self.estimate_join(&JoinOp {
574                    left: join.left.clone(),
575                    right: join.right.clone(),
576                    join_type: JoinType::Inner,
577                    conditions: join.conditions.clone(),
578                });
579                inner_card.max(right_card)
580            }
581            JoinType::Full => {
582                // Full join returns at least max(left, right)
583                let inner_card = self.estimate_join(&JoinOp {
584                    left: join.left.clone(),
585                    right: join.right.clone(),
586                    join_type: JoinType::Inner,
587                    conditions: join.conditions.clone(),
588                });
589                inner_card.max(left_card.max(right_card))
590            }
591            JoinType::Semi => {
592                // Semi join returns at most left cardinality
593                (left_card * self.default_selectivity).max(1.0)
594            }
595            JoinType::Anti => {
596                // Anti join returns at most left cardinality
597                (left_card * (1.0 - self.default_selectivity)).max(1.0)
598            }
599        }
600    }
601
602    /// Estimates aggregation cardinality.
603    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
604        let input_cardinality = self.estimate(&agg.input);
605
606        if agg.group_by.is_empty() {
607            // Global aggregation - single row
608            1.0
609        } else {
610            // Group by - estimate distinct groups
611            // Assume each group key reduces cardinality by 10
612            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
613            (input_cardinality / group_reduction).max(1.0)
614        }
615    }
616
617    /// Estimates sort cardinality (same as input).
618    fn estimate_sort(&self, sort: &SortOp) -> f64 {
619        self.estimate(&sort.input)
620    }
621
622    /// Estimates distinct cardinality.
623    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
624        let input_cardinality = self.estimate(&distinct.input);
625        // Assume 50% distinct by default
626        (input_cardinality * 0.5).max(1.0)
627    }
628
629    /// Estimates limit cardinality.
630    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
631        let input_cardinality = self.estimate(&limit.input);
632        (limit.count as f64).min(input_cardinality)
633    }
634
635    /// Estimates skip cardinality.
636    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
637        let input_cardinality = self.estimate(&skip.input);
638        (input_cardinality - skip.count as f64).max(0.0)
639    }
640
641    /// Estimates vector scan cardinality.
642    ///
643    /// Vector scan returns at most k results (the k nearest neighbors).
644    /// With similarity/distance filters, it may return fewer.
645    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
646        let base_k = scan.k as f64;
647
648        // Apply filter selectivity if thresholds are specified
649        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
650            // Assume 70% of results pass threshold filters
651            0.7
652        } else {
653            1.0
654        };
655
656        (base_k * selectivity).max(1.0)
657    }
658
659    /// Estimates vector join cardinality.
660    ///
661    /// Vector join produces up to k results per input row.
662    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
663        let input_cardinality = self.estimate(&join.input);
664        let k = join.k as f64;
665
666        // Apply filter selectivity if thresholds are specified
667        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
668            0.7
669        } else {
670            1.0
671        };
672
673        (input_cardinality * k * selectivity).max(1.0)
674    }
675
676    /// Estimates the selectivity of a predicate (0.0 to 1.0).
677    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
678        match expr {
679            LogicalExpression::Binary { left, op, right } => {
680                self.estimate_binary_selectivity(left, *op, right)
681            }
682            LogicalExpression::Unary { op, operand } => {
683                self.estimate_unary_selectivity(*op, operand)
684            }
685            LogicalExpression::Literal(value) => {
686                // Boolean literal
687                if let grafeo_common::types::Value::Bool(b) = value {
688                    if *b { 1.0 } else { 0.0 }
689                } else {
690                    self.default_selectivity
691                }
692            }
693            _ => self.default_selectivity,
694        }
695    }
696
697    /// Estimates binary expression selectivity.
698    fn estimate_binary_selectivity(
699        &self,
700        left: &LogicalExpression,
701        op: BinaryOp,
702        right: &LogicalExpression,
703    ) -> f64 {
704        match op {
705            // Equality - try histogram-based estimation
706            BinaryOp::Eq => {
707                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
708                    return selectivity;
709                }
710                0.01
711            }
712            // Inequality is very unselective
713            BinaryOp::Ne => 0.99,
714            // Range predicates - use histogram if available
715            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
716                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
717                    return selectivity;
718                }
719                0.33
720            }
721            // Logical operators - recursively estimate sub-expressions
722            BinaryOp::And => {
723                let left_sel = self.estimate_selectivity(left);
724                let right_sel = self.estimate_selectivity(right);
725                // AND reduces selectivity (multiply assuming independence)
726                left_sel * right_sel
727            }
728            BinaryOp::Or => {
729                let left_sel = self.estimate_selectivity(left);
730                let right_sel = self.estimate_selectivity(right);
731                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
732                // Assuming independence: P(A ∩ B) = P(A) * P(B)
733                (left_sel + right_sel - left_sel * right_sel).min(1.0)
734            }
735            // String operations
736            BinaryOp::StartsWith => 0.1,
737            BinaryOp::EndsWith => 0.1,
738            BinaryOp::Contains => 0.1,
739            BinaryOp::Like => 0.1,
740            // Collection membership
741            BinaryOp::In => 0.1,
742            // Other operations
743            _ => self.default_selectivity,
744        }
745    }
746
747    /// Tries to estimate equality selectivity using histograms.
748    fn try_equality_selectivity(
749        &self,
750        left: &LogicalExpression,
751        right: &LogicalExpression,
752    ) -> Option<f64> {
753        // Extract property access and literal value
754        let (label, column, value) = self.extract_column_and_value(left, right)?;
755
756        // Get column stats with histogram
757        let stats = self.get_column_stats(&label, &column)?;
758
759        // Try histogram-based estimation
760        if let Some(ref histogram) = stats.histogram {
761            return Some(histogram.equality_selectivity(value));
762        }
763
764        // Fall back to distinct count estimation
765        if stats.distinct_count > 0 {
766            return Some(1.0 / stats.distinct_count as f64);
767        }
768
769        None
770    }
771
772    /// Tries to estimate range selectivity using histograms.
773    fn try_range_selectivity(
774        &self,
775        left: &LogicalExpression,
776        op: BinaryOp,
777        right: &LogicalExpression,
778    ) -> Option<f64> {
779        // Extract property access and literal value
780        let (label, column, value) = self.extract_column_and_value(left, right)?;
781
782        // Get column stats
783        let stats = self.get_column_stats(&label, &column)?;
784
785        // Determine the range based on operator
786        let (lower, upper) = match op {
787            BinaryOp::Lt => (None, Some(value)),
788            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
789            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
790            BinaryOp::Ge => (Some(value), None),
791            _ => return None,
792        };
793
794        // Try histogram-based estimation first
795        if let Some(ref histogram) = stats.histogram {
796            return Some(histogram.range_selectivity(lower, upper));
797        }
798
799        // Fall back to min/max range estimation
800        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
801            let range = max - min;
802            if range <= 0.0 {
803                return Some(1.0);
804            }
805
806            let effective_lower = lower.unwrap_or(min).max(min);
807            let effective_upper = upper.unwrap_or(max).min(max);
808            let overlap = (effective_upper - effective_lower).max(0.0);
809            return Some((overlap / range).min(1.0).max(0.0));
810        }
811
812        None
813    }
814
815    /// Extracts column information and literal value from a comparison.
816    ///
817    /// Returns (label, column_name, numeric_value) if the expression is
818    /// a comparison between a property access and a numeric literal.
819    fn extract_column_and_value(
820        &self,
821        left: &LogicalExpression,
822        right: &LogicalExpression,
823    ) -> Option<(String, String, f64)> {
824        // Try left as property, right as literal
825        if let Some(result) = self.try_extract_property_literal(left, right) {
826            return Some(result);
827        }
828
829        // Try right as property, left as literal
830        self.try_extract_property_literal(right, left)
831    }
832
833    /// Tries to extract property and literal from a specific ordering.
834    fn try_extract_property_literal(
835        &self,
836        property_expr: &LogicalExpression,
837        literal_expr: &LogicalExpression,
838    ) -> Option<(String, String, f64)> {
839        // Extract property access
840        let (variable, property) = match property_expr {
841            LogicalExpression::Property { variable, property } => {
842                (variable.clone(), property.clone())
843            }
844            _ => return None,
845        };
846
847        // Extract numeric literal
848        let value = match literal_expr {
849            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
850            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
851            _ => return None,
852        };
853
854        // Try to find a label for this variable from table stats
855        // Use the variable name as a heuristic label lookup
856        // In practice, the optimizer would track which labels variables are bound to
857        for label in self.table_stats.keys() {
858            if let Some(stats) = self.table_stats.get(label)
859                && stats.columns.contains_key(&property)
860            {
861                return Some((label.clone(), property, value));
862            }
863        }
864
865        // If no stats found but we have the property, return with variable as label
866        Some((variable, property, value))
867    }
868
869    /// Estimates unary expression selectivity.
870    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
871        match op {
872            UnaryOp::Not => 1.0 - self.default_selectivity,
873            UnaryOp::IsNull => 0.05, // Assume 5% nulls
874            UnaryOp::IsNotNull => 0.95,
875            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
876        }
877    }
878
879    /// Gets statistics for a column.
880    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
881        self.table_stats.get(label)?.columns.get(column)
882    }
883
884    /// Estimates equality selectivity using column statistics.
885    #[allow(dead_code)]
886    fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
887        if let Some(stats) = self.get_column_stats(label, column)
888            && stats.distinct_count > 0
889        {
890            return 1.0 / stats.distinct_count as f64;
891        }
892        0.01 // Default for equality
893    }
894
895    /// Estimates range selectivity using column statistics.
896    #[allow(dead_code)]
897    fn estimate_range_with_stats(
898        &self,
899        label: &str,
900        column: &str,
901        lower: Option<f64>,
902        upper: Option<f64>,
903    ) -> f64 {
904        if let Some(stats) = self.get_column_stats(label, column)
905            && let (Some(min), Some(max)) = (stats.min_value, stats.max_value)
906        {
907            let range = max - min;
908            if range <= 0.0 {
909                return 1.0;
910            }
911
912            let effective_lower = lower.unwrap_or(min).max(min);
913            let effective_upper = upper.unwrap_or(max).min(max);
914
915            let overlap = (effective_upper - effective_lower).max(0.0);
916            return (overlap / range).min(1.0).max(0.0);
917        }
918        0.33 // Default for range
919    }
920}
921
922impl Default for CardinalityEstimator {
923    fn default() -> Self {
924        Self::new()
925    }
926}
927
928#[cfg(test)]
929mod tests {
930    use super::*;
931    use crate::query::plan::{
932        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
933        Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
934    };
935    use grafeo_common::types::Value;
936
937    #[test]
938    fn test_node_scan_with_stats() {
939        let mut estimator = CardinalityEstimator::new();
940        estimator.add_table_stats("Person", TableStats::new(5000));
941
942        let scan = LogicalOperator::NodeScan(NodeScanOp {
943            variable: "n".to_string(),
944            label: Some("Person".to_string()),
945            input: None,
946        });
947
948        let cardinality = estimator.estimate(&scan);
949        assert!((cardinality - 5000.0).abs() < 0.001);
950    }
951
952    #[test]
953    fn test_filter_reduces_cardinality() {
954        let mut estimator = CardinalityEstimator::new();
955        estimator.add_table_stats("Person", TableStats::new(1000));
956
957        let filter = LogicalOperator::Filter(FilterOp {
958            predicate: LogicalExpression::Binary {
959                left: Box::new(LogicalExpression::Property {
960                    variable: "n".to_string(),
961                    property: "age".to_string(),
962                }),
963                op: BinaryOp::Eq,
964                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
965            },
966            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
967                variable: "n".to_string(),
968                label: Some("Person".to_string()),
969                input: None,
970            })),
971        });
972
973        let cardinality = estimator.estimate(&filter);
974        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
975        assert!(cardinality < 1000.0);
976        assert!(cardinality >= 1.0);
977    }
978
979    #[test]
980    fn test_join_cardinality() {
981        let mut estimator = CardinalityEstimator::new();
982        estimator.add_table_stats("Person", TableStats::new(1000));
983        estimator.add_table_stats("Company", TableStats::new(100));
984
985        let join = LogicalOperator::Join(JoinOp {
986            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
987                variable: "p".to_string(),
988                label: Some("Person".to_string()),
989                input: None,
990            })),
991            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
992                variable: "c".to_string(),
993                label: Some("Company".to_string()),
994                input: None,
995            })),
996            join_type: JoinType::Inner,
997            conditions: vec![JoinCondition {
998                left: LogicalExpression::Property {
999                    variable: "p".to_string(),
1000                    property: "company_id".to_string(),
1001                },
1002                right: LogicalExpression::Property {
1003                    variable: "c".to_string(),
1004                    property: "id".to_string(),
1005                },
1006            }],
1007        });
1008
1009        let cardinality = estimator.estimate(&join);
1010        // Should be less than cross product
1011        assert!(cardinality < 1000.0 * 100.0);
1012    }
1013
1014    #[test]
1015    fn test_limit_caps_cardinality() {
1016        let mut estimator = CardinalityEstimator::new();
1017        estimator.add_table_stats("Person", TableStats::new(1000));
1018
1019        let limit = LogicalOperator::Limit(LimitOp {
1020            count: 10,
1021            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1022                variable: "n".to_string(),
1023                label: Some("Person".to_string()),
1024                input: None,
1025            })),
1026        });
1027
1028        let cardinality = estimator.estimate(&limit);
1029        assert!((cardinality - 10.0).abs() < 0.001);
1030    }
1031
1032    #[test]
1033    fn test_aggregate_reduces_cardinality() {
1034        let mut estimator = CardinalityEstimator::new();
1035        estimator.add_table_stats("Person", TableStats::new(1000));
1036
1037        // Global aggregation
1038        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1039            group_by: vec![],
1040            aggregates: vec![],
1041            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1042                variable: "n".to_string(),
1043                label: Some("Person".to_string()),
1044                input: None,
1045            })),
1046            having: None,
1047        });
1048
1049        let cardinality = estimator.estimate(&global_agg);
1050        assert!((cardinality - 1.0).abs() < 0.001);
1051
1052        // Group by aggregation
1053        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1054            group_by: vec![LogicalExpression::Property {
1055                variable: "n".to_string(),
1056                property: "city".to_string(),
1057            }],
1058            aggregates: vec![],
1059            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1060                variable: "n".to_string(),
1061                label: Some("Person".to_string()),
1062                input: None,
1063            })),
1064            having: None,
1065        });
1066
1067        let cardinality = estimator.estimate(&group_agg);
1068        // Should be less than input
1069        assert!(cardinality < 1000.0);
1070    }
1071
1072    #[test]
1073    fn test_node_scan_without_stats() {
1074        let estimator = CardinalityEstimator::new();
1075
1076        let scan = LogicalOperator::NodeScan(NodeScanOp {
1077            variable: "n".to_string(),
1078            label: Some("Unknown".to_string()),
1079            input: None,
1080        });
1081
1082        let cardinality = estimator.estimate(&scan);
1083        // Should return default (1000)
1084        assert!((cardinality - 1000.0).abs() < 0.001);
1085    }
1086
1087    #[test]
1088    fn test_node_scan_no_label() {
1089        let estimator = CardinalityEstimator::new();
1090
1091        let scan = LogicalOperator::NodeScan(NodeScanOp {
1092            variable: "n".to_string(),
1093            label: None,
1094            input: None,
1095        });
1096
1097        let cardinality = estimator.estimate(&scan);
1098        // Should scan all nodes (default)
1099        assert!((cardinality - 1000.0).abs() < 0.001);
1100    }
1101
1102    #[test]
1103    fn test_filter_inequality_selectivity() {
1104        let mut estimator = CardinalityEstimator::new();
1105        estimator.add_table_stats("Person", TableStats::new(1000));
1106
1107        let filter = LogicalOperator::Filter(FilterOp {
1108            predicate: LogicalExpression::Binary {
1109                left: Box::new(LogicalExpression::Property {
1110                    variable: "n".to_string(),
1111                    property: "age".to_string(),
1112                }),
1113                op: BinaryOp::Ne,
1114                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1115            },
1116            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1117                variable: "n".to_string(),
1118                label: Some("Person".to_string()),
1119                input: None,
1120            })),
1121        });
1122
1123        let cardinality = estimator.estimate(&filter);
1124        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1125        assert!(cardinality > 900.0);
1126    }
1127
1128    #[test]
1129    fn test_filter_range_selectivity() {
1130        let mut estimator = CardinalityEstimator::new();
1131        estimator.add_table_stats("Person", TableStats::new(1000));
1132
1133        let filter = LogicalOperator::Filter(FilterOp {
1134            predicate: LogicalExpression::Binary {
1135                left: Box::new(LogicalExpression::Property {
1136                    variable: "n".to_string(),
1137                    property: "age".to_string(),
1138                }),
1139                op: BinaryOp::Gt,
1140                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1141            },
1142            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1143                variable: "n".to_string(),
1144                label: Some("Person".to_string()),
1145                input: None,
1146            })),
1147        });
1148
1149        let cardinality = estimator.estimate(&filter);
1150        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1151        assert!(cardinality < 500.0);
1152        assert!(cardinality > 100.0);
1153    }
1154
1155    #[test]
1156    fn test_filter_and_selectivity() {
1157        let mut estimator = CardinalityEstimator::new();
1158        estimator.add_table_stats("Person", TableStats::new(1000));
1159
1160        // Test AND with two equality predicates
1161        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1162        let filter = LogicalOperator::Filter(FilterOp {
1163            predicate: LogicalExpression::Binary {
1164                left: Box::new(LogicalExpression::Binary {
1165                    left: Box::new(LogicalExpression::Property {
1166                        variable: "n".to_string(),
1167                        property: "city".to_string(),
1168                    }),
1169                    op: BinaryOp::Eq,
1170                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1171                }),
1172                op: BinaryOp::And,
1173                right: Box::new(LogicalExpression::Binary {
1174                    left: Box::new(LogicalExpression::Property {
1175                        variable: "n".to_string(),
1176                        property: "age".to_string(),
1177                    }),
1178                    op: BinaryOp::Eq,
1179                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1180                }),
1181            },
1182            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1183                variable: "n".to_string(),
1184                label: Some("Person".to_string()),
1185                input: None,
1186            })),
1187        });
1188
1189        let cardinality = estimator.estimate(&filter);
1190        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1191        // 1000 * 0.0001 = 0.1, min is 1.0
1192        assert!(cardinality < 100.0);
1193        assert!(cardinality >= 1.0);
1194    }
1195
1196    #[test]
1197    fn test_filter_or_selectivity() {
1198        let mut estimator = CardinalityEstimator::new();
1199        estimator.add_table_stats("Person", TableStats::new(1000));
1200
1201        // Test OR with two equality predicates
1202        // Each equality has selectivity 0.01
1203        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1204        let filter = LogicalOperator::Filter(FilterOp {
1205            predicate: LogicalExpression::Binary {
1206                left: Box::new(LogicalExpression::Binary {
1207                    left: Box::new(LogicalExpression::Property {
1208                        variable: "n".to_string(),
1209                        property: "city".to_string(),
1210                    }),
1211                    op: BinaryOp::Eq,
1212                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1213                }),
1214                op: BinaryOp::Or,
1215                right: Box::new(LogicalExpression::Binary {
1216                    left: Box::new(LogicalExpression::Property {
1217                        variable: "n".to_string(),
1218                        property: "city".to_string(),
1219                    }),
1220                    op: BinaryOp::Eq,
1221                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1222                }),
1223            },
1224            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1225                variable: "n".to_string(),
1226                label: Some("Person".to_string()),
1227                input: None,
1228            })),
1229        });
1230
1231        let cardinality = estimator.estimate(&filter);
1232        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1233        assert!(cardinality < 100.0);
1234        assert!(cardinality >= 1.0);
1235    }
1236
1237    #[test]
1238    fn test_filter_literal_true() {
1239        let mut estimator = CardinalityEstimator::new();
1240        estimator.add_table_stats("Person", TableStats::new(1000));
1241
1242        let filter = LogicalOperator::Filter(FilterOp {
1243            predicate: LogicalExpression::Literal(Value::Bool(true)),
1244            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1245                variable: "n".to_string(),
1246                label: Some("Person".to_string()),
1247                input: None,
1248            })),
1249        });
1250
1251        let cardinality = estimator.estimate(&filter);
1252        // Literal true has selectivity 1.0
1253        assert!((cardinality - 1000.0).abs() < 0.001);
1254    }
1255
1256    #[test]
1257    fn test_filter_literal_false() {
1258        let mut estimator = CardinalityEstimator::new();
1259        estimator.add_table_stats("Person", TableStats::new(1000));
1260
1261        let filter = LogicalOperator::Filter(FilterOp {
1262            predicate: LogicalExpression::Literal(Value::Bool(false)),
1263            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1264                variable: "n".to_string(),
1265                label: Some("Person".to_string()),
1266                input: None,
1267            })),
1268        });
1269
1270        let cardinality = estimator.estimate(&filter);
1271        // Literal false has selectivity 0.0, but min is 1.0
1272        assert!((cardinality - 1.0).abs() < 0.001);
1273    }
1274
1275    #[test]
1276    fn test_unary_not_selectivity() {
1277        let mut estimator = CardinalityEstimator::new();
1278        estimator.add_table_stats("Person", TableStats::new(1000));
1279
1280        let filter = LogicalOperator::Filter(FilterOp {
1281            predicate: LogicalExpression::Unary {
1282                op: UnaryOp::Not,
1283                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1284            },
1285            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286                variable: "n".to_string(),
1287                label: Some("Person".to_string()),
1288                input: None,
1289            })),
1290        });
1291
1292        let cardinality = estimator.estimate(&filter);
1293        // NOT inverts selectivity
1294        assert!(cardinality < 1000.0);
1295    }
1296
1297    #[test]
1298    fn test_unary_is_null_selectivity() {
1299        let mut estimator = CardinalityEstimator::new();
1300        estimator.add_table_stats("Person", TableStats::new(1000));
1301
1302        let filter = LogicalOperator::Filter(FilterOp {
1303            predicate: LogicalExpression::Unary {
1304                op: UnaryOp::IsNull,
1305                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1306            },
1307            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1308                variable: "n".to_string(),
1309                label: Some("Person".to_string()),
1310                input: None,
1311            })),
1312        });
1313
1314        let cardinality = estimator.estimate(&filter);
1315        // IS NULL has selectivity 0.05
1316        assert!(cardinality < 100.0);
1317    }
1318
1319    #[test]
1320    fn test_expand_cardinality() {
1321        let mut estimator = CardinalityEstimator::new();
1322        estimator.add_table_stats("Person", TableStats::new(100));
1323
1324        let expand = LogicalOperator::Expand(ExpandOp {
1325            from_variable: "a".to_string(),
1326            to_variable: "b".to_string(),
1327            edge_variable: None,
1328            direction: ExpandDirection::Outgoing,
1329            edge_type: None,
1330            min_hops: 1,
1331            max_hops: Some(1),
1332            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1333                variable: "a".to_string(),
1334                label: Some("Person".to_string()),
1335                input: None,
1336            })),
1337            path_alias: None,
1338        });
1339
1340        let cardinality = estimator.estimate(&expand);
1341        // Expand multiplies by fanout (10)
1342        assert!(cardinality > 100.0);
1343    }
1344
1345    #[test]
1346    fn test_expand_with_edge_type_filter() {
1347        let mut estimator = CardinalityEstimator::new();
1348        estimator.add_table_stats("Person", TableStats::new(100));
1349
1350        let expand = LogicalOperator::Expand(ExpandOp {
1351            from_variable: "a".to_string(),
1352            to_variable: "b".to_string(),
1353            edge_variable: None,
1354            direction: ExpandDirection::Outgoing,
1355            edge_type: Some("KNOWS".to_string()),
1356            min_hops: 1,
1357            max_hops: Some(1),
1358            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1359                variable: "a".to_string(),
1360                label: Some("Person".to_string()),
1361                input: None,
1362            })),
1363            path_alias: None,
1364        });
1365
1366        let cardinality = estimator.estimate(&expand);
1367        // With edge type, fanout is reduced by half
1368        assert!(cardinality > 100.0);
1369    }
1370
1371    #[test]
1372    fn test_expand_variable_length() {
1373        let mut estimator = CardinalityEstimator::new();
1374        estimator.add_table_stats("Person", TableStats::new(100));
1375
1376        let expand = LogicalOperator::Expand(ExpandOp {
1377            from_variable: "a".to_string(),
1378            to_variable: "b".to_string(),
1379            edge_variable: None,
1380            direction: ExpandDirection::Outgoing,
1381            edge_type: None,
1382            min_hops: 1,
1383            max_hops: Some(3),
1384            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1385                variable: "a".to_string(),
1386                label: Some("Person".to_string()),
1387                input: None,
1388            })),
1389            path_alias: None,
1390        });
1391
1392        let cardinality = estimator.estimate(&expand);
1393        // Variable length path has much higher cardinality
1394        assert!(cardinality > 500.0);
1395    }
1396
1397    #[test]
1398    fn test_join_cross_product() {
1399        let mut estimator = CardinalityEstimator::new();
1400        estimator.add_table_stats("Person", TableStats::new(100));
1401        estimator.add_table_stats("Company", TableStats::new(50));
1402
1403        let join = LogicalOperator::Join(JoinOp {
1404            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1405                variable: "p".to_string(),
1406                label: Some("Person".to_string()),
1407                input: None,
1408            })),
1409            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1410                variable: "c".to_string(),
1411                label: Some("Company".to_string()),
1412                input: None,
1413            })),
1414            join_type: JoinType::Cross,
1415            conditions: vec![],
1416        });
1417
1418        let cardinality = estimator.estimate(&join);
1419        // Cross join = 100 * 50 = 5000
1420        assert!((cardinality - 5000.0).abs() < 0.001);
1421    }
1422
1423    #[test]
1424    fn test_join_left_outer() {
1425        let mut estimator = CardinalityEstimator::new();
1426        estimator.add_table_stats("Person", TableStats::new(1000));
1427        estimator.add_table_stats("Company", TableStats::new(10));
1428
1429        let join = LogicalOperator::Join(JoinOp {
1430            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1431                variable: "p".to_string(),
1432                label: Some("Person".to_string()),
1433                input: None,
1434            })),
1435            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1436                variable: "c".to_string(),
1437                label: Some("Company".to_string()),
1438                input: None,
1439            })),
1440            join_type: JoinType::Left,
1441            conditions: vec![JoinCondition {
1442                left: LogicalExpression::Variable("p".to_string()),
1443                right: LogicalExpression::Variable("c".to_string()),
1444            }],
1445        });
1446
1447        let cardinality = estimator.estimate(&join);
1448        // Left join returns at least all left rows
1449        assert!(cardinality >= 1000.0);
1450    }
1451
1452    #[test]
1453    fn test_join_semi() {
1454        let mut estimator = CardinalityEstimator::new();
1455        estimator.add_table_stats("Person", TableStats::new(1000));
1456        estimator.add_table_stats("Company", TableStats::new(100));
1457
1458        let join = LogicalOperator::Join(JoinOp {
1459            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1460                variable: "p".to_string(),
1461                label: Some("Person".to_string()),
1462                input: None,
1463            })),
1464            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1465                variable: "c".to_string(),
1466                label: Some("Company".to_string()),
1467                input: None,
1468            })),
1469            join_type: JoinType::Semi,
1470            conditions: vec![],
1471        });
1472
1473        let cardinality = estimator.estimate(&join);
1474        // Semi join returns at most left cardinality
1475        assert!(cardinality <= 1000.0);
1476    }
1477
1478    #[test]
1479    fn test_join_anti() {
1480        let mut estimator = CardinalityEstimator::new();
1481        estimator.add_table_stats("Person", TableStats::new(1000));
1482        estimator.add_table_stats("Company", TableStats::new(100));
1483
1484        let join = LogicalOperator::Join(JoinOp {
1485            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1486                variable: "p".to_string(),
1487                label: Some("Person".to_string()),
1488                input: None,
1489            })),
1490            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1491                variable: "c".to_string(),
1492                label: Some("Company".to_string()),
1493                input: None,
1494            })),
1495            join_type: JoinType::Anti,
1496            conditions: vec![],
1497        });
1498
1499        let cardinality = estimator.estimate(&join);
1500        // Anti join returns at most left cardinality
1501        assert!(cardinality <= 1000.0);
1502        assert!(cardinality >= 1.0);
1503    }
1504
1505    #[test]
1506    fn test_project_preserves_cardinality() {
1507        let mut estimator = CardinalityEstimator::new();
1508        estimator.add_table_stats("Person", TableStats::new(1000));
1509
1510        let project = LogicalOperator::Project(ProjectOp {
1511            projections: vec![Projection {
1512                expression: LogicalExpression::Variable("n".to_string()),
1513                alias: None,
1514            }],
1515            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1516                variable: "n".to_string(),
1517                label: Some("Person".to_string()),
1518                input: None,
1519            })),
1520        });
1521
1522        let cardinality = estimator.estimate(&project);
1523        assert!((cardinality - 1000.0).abs() < 0.001);
1524    }
1525
1526    #[test]
1527    fn test_sort_preserves_cardinality() {
1528        let mut estimator = CardinalityEstimator::new();
1529        estimator.add_table_stats("Person", TableStats::new(1000));
1530
1531        let sort = LogicalOperator::Sort(SortOp {
1532            keys: vec![SortKey {
1533                expression: LogicalExpression::Variable("n".to_string()),
1534                order: SortOrder::Ascending,
1535            }],
1536            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1537                variable: "n".to_string(),
1538                label: Some("Person".to_string()),
1539                input: None,
1540            })),
1541        });
1542
1543        let cardinality = estimator.estimate(&sort);
1544        assert!((cardinality - 1000.0).abs() < 0.001);
1545    }
1546
1547    #[test]
1548    fn test_distinct_reduces_cardinality() {
1549        let mut estimator = CardinalityEstimator::new();
1550        estimator.add_table_stats("Person", TableStats::new(1000));
1551
1552        let distinct = LogicalOperator::Distinct(DistinctOp {
1553            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1554                variable: "n".to_string(),
1555                label: Some("Person".to_string()),
1556                input: None,
1557            })),
1558            columns: None,
1559        });
1560
1561        let cardinality = estimator.estimate(&distinct);
1562        // Distinct assumes 50% unique
1563        assert!((cardinality - 500.0).abs() < 0.001);
1564    }
1565
1566    #[test]
1567    fn test_skip_reduces_cardinality() {
1568        let mut estimator = CardinalityEstimator::new();
1569        estimator.add_table_stats("Person", TableStats::new(1000));
1570
1571        let skip = LogicalOperator::Skip(SkipOp {
1572            count: 100,
1573            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1574                variable: "n".to_string(),
1575                label: Some("Person".to_string()),
1576                input: None,
1577            })),
1578        });
1579
1580        let cardinality = estimator.estimate(&skip);
1581        assert!((cardinality - 900.0).abs() < 0.001);
1582    }
1583
1584    #[test]
1585    fn test_return_preserves_cardinality() {
1586        let mut estimator = CardinalityEstimator::new();
1587        estimator.add_table_stats("Person", TableStats::new(1000));
1588
1589        let ret = LogicalOperator::Return(ReturnOp {
1590            items: vec![ReturnItem {
1591                expression: LogicalExpression::Variable("n".to_string()),
1592                alias: None,
1593            }],
1594            distinct: false,
1595            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1596                variable: "n".to_string(),
1597                label: Some("Person".to_string()),
1598                input: None,
1599            })),
1600        });
1601
1602        let cardinality = estimator.estimate(&ret);
1603        assert!((cardinality - 1000.0).abs() < 0.001);
1604    }
1605
1606    #[test]
1607    fn test_empty_cardinality() {
1608        let estimator = CardinalityEstimator::new();
1609        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1610        assert!((cardinality).abs() < 0.001);
1611    }
1612
1613    #[test]
1614    fn test_table_stats_with_column() {
1615        let stats = TableStats::new(1000).with_column(
1616            "age",
1617            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1618        );
1619
1620        assert_eq!(stats.row_count, 1000);
1621        let col = stats.columns.get("age").unwrap();
1622        assert_eq!(col.distinct_count, 50);
1623        assert_eq!(col.null_count, 10);
1624        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1625        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1626    }
1627
1628    #[test]
1629    fn test_estimator_default() {
1630        let estimator = CardinalityEstimator::default();
1631        let scan = LogicalOperator::NodeScan(NodeScanOp {
1632            variable: "n".to_string(),
1633            label: None,
1634            input: None,
1635        });
1636        let cardinality = estimator.estimate(&scan);
1637        assert!((cardinality - 1000.0).abs() < 0.001);
1638    }
1639
1640    #[test]
1641    fn test_set_avg_fanout() {
1642        let mut estimator = CardinalityEstimator::new();
1643        estimator.add_table_stats("Person", TableStats::new(100));
1644        estimator.set_avg_fanout(5.0);
1645
1646        let expand = LogicalOperator::Expand(ExpandOp {
1647            from_variable: "a".to_string(),
1648            to_variable: "b".to_string(),
1649            edge_variable: None,
1650            direction: ExpandDirection::Outgoing,
1651            edge_type: None,
1652            min_hops: 1,
1653            max_hops: Some(1),
1654            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1655                variable: "a".to_string(),
1656                label: Some("Person".to_string()),
1657                input: None,
1658            })),
1659            path_alias: None,
1660        });
1661
1662        let cardinality = estimator.estimate(&expand);
1663        // With fanout 5: 100 * 5 = 500
1664        assert!((cardinality - 500.0).abs() < 0.001);
1665    }
1666
1667    #[test]
1668    fn test_multiple_group_by_keys_reduce_cardinality() {
1669        // The current implementation uses a simplified model where more group by keys
1670        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1671        // that works for most cases where group by keys are correlated.
1672        let mut estimator = CardinalityEstimator::new();
1673        estimator.add_table_stats("Person", TableStats::new(10000));
1674
1675        let single_group = LogicalOperator::Aggregate(AggregateOp {
1676            group_by: vec![LogicalExpression::Property {
1677                variable: "n".to_string(),
1678                property: "city".to_string(),
1679            }],
1680            aggregates: vec![],
1681            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1682                variable: "n".to_string(),
1683                label: Some("Person".to_string()),
1684                input: None,
1685            })),
1686            having: None,
1687        });
1688
1689        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1690            group_by: vec![
1691                LogicalExpression::Property {
1692                    variable: "n".to_string(),
1693                    property: "city".to_string(),
1694                },
1695                LogicalExpression::Property {
1696                    variable: "n".to_string(),
1697                    property: "country".to_string(),
1698                },
1699            ],
1700            aggregates: vec![],
1701            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1702                variable: "n".to_string(),
1703                label: Some("Person".to_string()),
1704                input: None,
1705            })),
1706            having: None,
1707        });
1708
1709        let single_card = estimator.estimate(&single_group);
1710        let multi_card = estimator.estimate(&multi_group);
1711
1712        // Both should reduce cardinality from input
1713        assert!(single_card < 10000.0);
1714        assert!(multi_card < 10000.0);
1715        // Both should be at least 1
1716        assert!(single_card >= 1.0);
1717        assert!(multi_card >= 1.0);
1718    }
1719
1720    // ============= Histogram Tests =============
1721
1722    #[test]
1723    fn test_histogram_build_uniform() {
1724        // Build histogram from uniformly distributed data
1725        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1726        let histogram = EquiDepthHistogram::build(&values, 10);
1727
1728        assert_eq!(histogram.num_buckets(), 10);
1729        assert_eq!(histogram.total_rows(), 100);
1730
1731        // Each bucket should have approximately 10 rows
1732        for bucket in histogram.buckets() {
1733            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1734        }
1735    }
1736
1737    #[test]
1738    fn test_histogram_build_skewed() {
1739        // Build histogram from skewed data (many small values, few large)
1740        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1741        values.extend((0..20).map(|i| 1000.0 + i as f64));
1742        let histogram = EquiDepthHistogram::build(&values, 5);
1743
1744        assert_eq!(histogram.num_buckets(), 5);
1745        assert_eq!(histogram.total_rows(), 100);
1746
1747        // Each bucket should have ~20 rows despite skewed data
1748        for bucket in histogram.buckets() {
1749            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1750        }
1751    }
1752
1753    #[test]
1754    fn test_histogram_range_selectivity_full() {
1755        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1756        let histogram = EquiDepthHistogram::build(&values, 10);
1757
1758        // Full range should have selectivity ~1.0
1759        let selectivity = histogram.range_selectivity(None, None);
1760        assert!((selectivity - 1.0).abs() < 0.01);
1761    }
1762
1763    #[test]
1764    fn test_histogram_range_selectivity_half() {
1765        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1766        let histogram = EquiDepthHistogram::build(&values, 10);
1767
1768        // Values >= 50 should be ~50% (half the data)
1769        let selectivity = histogram.range_selectivity(Some(50.0), None);
1770        assert!(selectivity > 0.4 && selectivity < 0.6);
1771    }
1772
1773    #[test]
1774    fn test_histogram_range_selectivity_quarter() {
1775        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1776        let histogram = EquiDepthHistogram::build(&values, 10);
1777
1778        // Values < 25 should be ~25%
1779        let selectivity = histogram.range_selectivity(None, Some(25.0));
1780        assert!(selectivity > 0.2 && selectivity < 0.3);
1781    }
1782
1783    #[test]
1784    fn test_histogram_equality_selectivity() {
1785        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1786        let histogram = EquiDepthHistogram::build(&values, 10);
1787
1788        // Equality on 100 distinct values should be ~1%
1789        let selectivity = histogram.equality_selectivity(50.0);
1790        assert!(selectivity > 0.005 && selectivity < 0.02);
1791    }
1792
1793    #[test]
1794    fn test_histogram_empty() {
1795        let histogram = EquiDepthHistogram::build(&[], 10);
1796
1797        assert_eq!(histogram.num_buckets(), 0);
1798        assert_eq!(histogram.total_rows(), 0);
1799
1800        // Default selectivity for empty histogram
1801        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1802        assert!((selectivity - 0.33).abs() < 0.01);
1803    }
1804
1805    #[test]
1806    fn test_histogram_bucket_overlap() {
1807        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1808
1809        // Full overlap
1810        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1811
1812        // Half overlap (lower half)
1813        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1814
1815        // Half overlap (upper half)
1816        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1817
1818        // No overlap (below)
1819        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1820
1821        // No overlap (above)
1822        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1823    }
1824
1825    #[test]
1826    fn test_column_stats_from_values() {
1827        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
1828        let stats = ColumnStats::from_values(values, 4);
1829
1830        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
1831        assert!(stats.min_value.is_some());
1832        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
1833        assert!(stats.max_value.is_some());
1834        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
1835        assert!(stats.histogram.is_some());
1836    }
1837
1838    #[test]
1839    fn test_column_stats_with_histogram_builder() {
1840        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1841        let histogram = EquiDepthHistogram::build(&values, 10);
1842
1843        let stats = ColumnStats::new(100)
1844            .with_range(0.0, 99.0)
1845            .with_histogram(histogram);
1846
1847        assert!(stats.histogram.is_some());
1848        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
1849    }
1850
1851    #[test]
1852    fn test_filter_with_histogram_stats() {
1853        let mut estimator = CardinalityEstimator::new();
1854
1855        // Create stats with histogram for age column
1856        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
1857        let histogram = EquiDepthHistogram::build(&age_values, 10);
1858        let age_stats = ColumnStats::new(62)
1859            .with_range(18.0, 79.0)
1860            .with_histogram(histogram);
1861
1862        estimator.add_table_stats(
1863            "Person",
1864            TableStats::new(1000).with_column("age", age_stats),
1865        );
1866
1867        // Filter: age > 50
1868        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
1869        let filter = LogicalOperator::Filter(FilterOp {
1870            predicate: LogicalExpression::Binary {
1871                left: Box::new(LogicalExpression::Property {
1872                    variable: "n".to_string(),
1873                    property: "age".to_string(),
1874                }),
1875                op: BinaryOp::Gt,
1876                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1877            },
1878            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1879                variable: "n".to_string(),
1880                label: Some("Person".to_string()),
1881                input: None,
1882            })),
1883        });
1884
1885        let cardinality = estimator.estimate(&filter);
1886
1887        // With histogram, should get more accurate estimate than default 0.33
1888        // Expected: ~47.5% of 1000 = ~475
1889        assert!(cardinality > 300.0 && cardinality < 600.0);
1890    }
1891
1892    #[test]
1893    fn test_filter_equality_with_histogram() {
1894        let mut estimator = CardinalityEstimator::new();
1895
1896        // Create stats with histogram
1897        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1898        let histogram = EquiDepthHistogram::build(&values, 10);
1899        let stats = ColumnStats::new(100)
1900            .with_range(0.0, 99.0)
1901            .with_histogram(histogram);
1902
1903        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
1904
1905        // Filter: value = 50
1906        let filter = LogicalOperator::Filter(FilterOp {
1907            predicate: LogicalExpression::Binary {
1908                left: Box::new(LogicalExpression::Property {
1909                    variable: "d".to_string(),
1910                    property: "value".to_string(),
1911                }),
1912                op: BinaryOp::Eq,
1913                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1914            },
1915            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1916                variable: "d".to_string(),
1917                label: Some("Data".to_string()),
1918                input: None,
1919            })),
1920        });
1921
1922        let cardinality = estimator.estimate(&filter);
1923
1924        // With 100 distinct values, selectivity should be ~1%
1925        // 1000 * 0.01 = 10
1926        assert!(cardinality >= 1.0 && cardinality < 50.0);
1927    }
1928
1929    #[test]
1930    fn test_histogram_min_max() {
1931        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
1932        let histogram = EquiDepthHistogram::build(&values, 2);
1933
1934        assert_eq!(histogram.min_value(), Some(5.0));
1935        // Max is the upper bound of the last bucket
1936        assert!(histogram.max_value().is_some());
1937    }
1938}