Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20};
21use std::collections::HashMap;
22
/// One bucket of an equi-depth histogram.
///
/// A bucket describes a value interval together with the number of rows and
/// the number of distinct values that fall inside it. In an equi-depth
/// histogram every bucket holds roughly the same number of rows.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the interval.
    pub lower_bound: f64,
    /// Upper edge of the interval (exclusive, except for the last bucket).
    pub upper_bound: f64,
    /// Row count of the bucket.
    pub frequency: u64,
    /// Count of distinct values among those rows.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Assembles a bucket from its bounds and counters.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        HistogramBucket {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Distance between the upper and the lower bound.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Tells whether `value` lies in the half-open interval
    /// `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Fraction (0.0 to 1.0) of this bucket that the range `lower..upper`
    /// covers.
    ///
    /// A `None` bound means "unbounded on that side". For a bucket whose width
    /// is zero or negative the answer is all-or-nothing: `1.0` when the
    /// clipped range still reaches both bucket edges, `0.0` otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the query range to the bucket; an open side uses the bucket edge.
        let clipped_lo = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let clipped_hi = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // No width to interpolate over.
            let reaches_both_edges =
                clipped_lo <= self.lower_bound && clipped_hi >= self.upper_bound;
            if reaches_both_edges { 1.0 } else { 0.0 }
        } else {
            let covered = (clipped_hi - clipped_lo).max(0.0);
            (covered / span).min(1.0)
        }
    }
}
83
84/// An equi-depth histogram for selectivity estimation.
85///
86/// Equi-depth histograms partition data into buckets where each bucket
87/// contains approximately the same number of rows. This provides more
88/// accurate selectivity estimates than assuming uniform distribution,
89/// especially for skewed data.
90///
91/// # Example
92///
93/// ```ignore
94/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
95///
96/// // Build a histogram from sorted values
97/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
98/// let histogram = EquiDepthHistogram::build(&values, 4);
99///
100/// // Estimate selectivity for age > 25
101/// let selectivity = histogram.range_selectivity(Some(25.0), None);
102/// ```
103#[derive(Debug, Clone)]
104pub struct EquiDepthHistogram {
105    /// The histogram buckets, sorted by lower_bound.
106    buckets: Vec<HistogramBucket>,
107    /// Total number of rows represented by this histogram.
108    total_rows: u64,
109}
110
111impl EquiDepthHistogram {
112    /// Creates a new histogram from pre-built buckets.
113    #[must_use]
114    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
115        let total_rows = buckets.iter().map(|b| b.frequency).sum();
116        Self {
117            buckets,
118            total_rows,
119        }
120    }
121
122    /// Builds an equi-depth histogram from a sorted slice of values.
123    ///
124    /// # Arguments
125    /// * `values` - A sorted slice of numeric values
126    /// * `num_buckets` - The desired number of buckets
127    ///
128    /// # Returns
129    /// An equi-depth histogram with approximately equal row counts per bucket.
130    #[must_use]
131    pub fn build(values: &[f64], num_buckets: usize) -> Self {
132        if values.is_empty() || num_buckets == 0 {
133            return Self {
134                buckets: Vec::new(),
135                total_rows: 0,
136            };
137        }
138
139        let num_buckets = num_buckets.min(values.len());
140        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
141        let mut buckets = Vec::with_capacity(num_buckets);
142
143        let mut start_idx = 0;
144        while start_idx < values.len() {
145            let end_idx = (start_idx + rows_per_bucket).min(values.len());
146            let lower_bound = values[start_idx];
147            let upper_bound = if end_idx < values.len() {
148                values[end_idx]
149            } else {
150                // For the last bucket, extend slightly beyond the max value
151                values[end_idx - 1] + 1.0
152            };
153
154            // Count distinct values in this bucket
155            let bucket_values = &values[start_idx..end_idx];
156            let distinct_count = count_distinct(bucket_values);
157
158            buckets.push(HistogramBucket::new(
159                lower_bound,
160                upper_bound,
161                (end_idx - start_idx) as u64,
162                distinct_count,
163            ));
164
165            start_idx = end_idx;
166        }
167
168        Self::new(buckets)
169    }
170
171    /// Returns the number of buckets in this histogram.
172    #[must_use]
173    pub fn num_buckets(&self) -> usize {
174        self.buckets.len()
175    }
176
177    /// Returns the total number of rows represented.
178    #[must_use]
179    pub fn total_rows(&self) -> u64 {
180        self.total_rows
181    }
182
183    /// Returns the histogram buckets.
184    #[must_use]
185    pub fn buckets(&self) -> &[HistogramBucket] {
186        &self.buckets
187    }
188
189    /// Estimates selectivity for a range predicate.
190    ///
191    /// # Arguments
192    /// * `lower` - Lower bound (None for unbounded)
193    /// * `upper` - Upper bound (None for unbounded)
194    ///
195    /// # Returns
196    /// Estimated fraction of rows matching the range (0.0 to 1.0).
197    #[must_use]
198    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
199        if self.buckets.is_empty() || self.total_rows == 0 {
200            return 0.33; // Default fallback
201        }
202
203        let mut matching_rows = 0.0;
204
205        for bucket in &self.buckets {
206            // Check if this bucket overlaps with the range
207            let bucket_lower = bucket.lower_bound;
208            let bucket_upper = bucket.upper_bound;
209
210            // Skip buckets entirely outside the range
211            if let Some(l) = lower {
212                if bucket_upper <= l {
213                    continue;
214                }
215            }
216            if let Some(u) = upper {
217                if bucket_lower >= u {
218                    continue;
219                }
220            }
221
222            // Calculate the fraction of this bucket covered by the range
223            let overlap = bucket.overlap_fraction(lower, upper);
224            matching_rows += overlap * bucket.frequency as f64;
225        }
226
227        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
228    }
229
230    /// Estimates selectivity for an equality predicate.
231    ///
232    /// Uses the distinct count within matching buckets for better accuracy.
233    #[must_use]
234    pub fn equality_selectivity(&self, value: f64) -> f64 {
235        if self.buckets.is_empty() || self.total_rows == 0 {
236            return 0.01; // Default fallback
237        }
238
239        // Find the bucket containing this value
240        for bucket in &self.buckets {
241            if bucket.contains(value) {
242                // Assume uniform distribution within bucket
243                if bucket.distinct_count > 0 {
244                    return (bucket.frequency as f64
245                        / bucket.distinct_count as f64
246                        / self.total_rows as f64)
247                        .min(1.0);
248                }
249            }
250        }
251
252        // Value not in any bucket - very low selectivity
253        0.001
254    }
255
256    /// Gets the minimum value in the histogram.
257    #[must_use]
258    pub fn min_value(&self) -> Option<f64> {
259        self.buckets.first().map(|b| b.lower_bound)
260    }
261
262    /// Gets the maximum value in the histogram.
263    #[must_use]
264    pub fn max_value(&self) -> Option<f64> {
265        self.buckets.last().map(|b| b.upper_bound)
266    }
267}
268
/// Counts distinct values in a sorted slice.
///
/// Two neighbouring values are distinct when one is strictly smaller or
/// larger than the other. The comparison is exact: a fixed absolute tolerance
/// such as `f64::EPSILON` would merge different values of small magnitude.
/// `-0.0` and `0.0` count as the same value, and a NaN never starts a new
/// distinct value because it compares neither smaller nor larger.
///
/// Returns `0` for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let (mut prev, rest) = match sorted_values.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return 0,
    };

    let mut count = 1u64;
    for &val in rest {
        // Ordered comparison instead of `!=` so that NaN is never counted.
        if val < prev || val > prev {
            count += 1;
            prev = val;
        }
    }

    count
}
287
288/// Statistics for a table/label.
289#[derive(Debug, Clone)]
290pub struct TableStats {
291    /// Total number of rows.
292    pub row_count: u64,
293    /// Column statistics.
294    pub columns: HashMap<String, ColumnStats>,
295}
296
297impl TableStats {
298    /// Creates new table statistics.
299    #[must_use]
300    pub fn new(row_count: u64) -> Self {
301        Self {
302            row_count,
303            columns: HashMap::new(),
304        }
305    }
306
307    /// Adds column statistics.
308    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
309        self.columns.insert(name.to_string(), stats);
310        self
311    }
312}
313
314/// Statistics for a column.
315#[derive(Debug, Clone)]
316pub struct ColumnStats {
317    /// Number of distinct values.
318    pub distinct_count: u64,
319    /// Number of null values.
320    pub null_count: u64,
321    /// Minimum value (if orderable).
322    pub min_value: Option<f64>,
323    /// Maximum value (if orderable).
324    pub max_value: Option<f64>,
325    /// Equi-depth histogram for accurate selectivity estimation.
326    pub histogram: Option<EquiDepthHistogram>,
327}
328
329impl ColumnStats {
330    /// Creates new column statistics.
331    #[must_use]
332    pub fn new(distinct_count: u64) -> Self {
333        Self {
334            distinct_count,
335            null_count: 0,
336            min_value: None,
337            max_value: None,
338            histogram: None,
339        }
340    }
341
342    /// Sets the null count.
343    #[must_use]
344    pub fn with_nulls(mut self, null_count: u64) -> Self {
345        self.null_count = null_count;
346        self
347    }
348
349    /// Sets the min/max range.
350    #[must_use]
351    pub fn with_range(mut self, min: f64, max: f64) -> Self {
352        self.min_value = Some(min);
353        self.max_value = Some(max);
354        self
355    }
356
357    /// Sets the equi-depth histogram for this column.
358    #[must_use]
359    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
360        self.histogram = Some(histogram);
361        self
362    }
363
364    /// Builds column statistics with histogram from raw values.
365    ///
366    /// This is a convenience method that computes all statistics from the data.
367    ///
368    /// # Arguments
369    /// * `values` - The column values (will be sorted internally)
370    /// * `num_buckets` - Number of histogram buckets to create
371    #[must_use]
372    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
373        if values.is_empty() {
374            return Self::new(0);
375        }
376
377        // Sort values for histogram building
378        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
379
380        let min = values.first().copied();
381        let max = values.last().copied();
382        let distinct_count = count_distinct(&values);
383        let histogram = EquiDepthHistogram::build(&values, num_buckets);
384
385        Self {
386            distinct_count,
387            null_count: 0,
388            min_value: min,
389            max_value: max,
390            histogram: Some(histogram),
391        }
392    }
393}
394
/// Cardinality estimator.
///
/// Holds per-label table statistics together with the fallback constants that
/// are used whenever no statistics apply to an operator or predicate.
pub struct CardinalityEstimator {
    /// Statistics for each label/table, keyed by label name.
    table_stats: HashMap<String, TableStats>,
    /// Default row count for unknown tables.
    default_row_count: u64,
    /// Default selectivity for unknown predicates.
    default_selectivity: f64,
    /// Average edge fanout (outgoing edges per node).
    avg_fanout: f64,
}
406
407impl CardinalityEstimator {
408    /// Creates a new cardinality estimator with default settings.
409    #[must_use]
410    pub fn new() -> Self {
411        Self {
412            table_stats: HashMap::new(),
413            default_row_count: 1000,
414            default_selectivity: 0.1,
415            avg_fanout: 10.0,
416        }
417    }
418
419    /// Adds statistics for a table/label.
420    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
421        self.table_stats.insert(name.to_string(), stats);
422    }
423
424    /// Sets the average edge fanout.
425    pub fn set_avg_fanout(&mut self, fanout: f64) {
426        self.avg_fanout = fanout;
427    }
428
429    /// Estimates the cardinality of a logical operator.
430    #[must_use]
431    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
432        match op {
433            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
434            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
435            LogicalOperator::Project(project) => self.estimate_project(project),
436            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
437            LogicalOperator::Join(join) => self.estimate_join(join),
438            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
439            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
440            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
441            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
442            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
443            LogicalOperator::Return(ret) => self.estimate(&ret.input),
444            LogicalOperator::Empty => 0.0,
445            _ => self.default_row_count as f64,
446        }
447    }
448
449    /// Estimates node scan cardinality.
450    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
451        if let Some(label) = &scan.label {
452            if let Some(stats) = self.table_stats.get(label) {
453                return stats.row_count as f64;
454            }
455        }
456        // No label filter - scan all nodes
457        self.default_row_count as f64
458    }
459
460    /// Estimates filter cardinality.
461    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
462        let input_cardinality = self.estimate(&filter.input);
463        let selectivity = self.estimate_selectivity(&filter.predicate);
464        (input_cardinality * selectivity).max(1.0)
465    }
466
467    /// Estimates projection cardinality (same as input).
468    fn estimate_project(&self, project: &ProjectOp) -> f64 {
469        self.estimate(&project.input)
470    }
471
472    /// Estimates expand cardinality.
473    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
474        let input_cardinality = self.estimate(&expand.input);
475
476        // Apply fanout based on edge type
477        let fanout = if expand.edge_type.is_some() {
478            // Specific edge type typically has lower fanout
479            self.avg_fanout * 0.5
480        } else {
481            self.avg_fanout
482        };
483
484        // Handle variable-length paths
485        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
486            let min = expand.min_hops as f64;
487            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
488            // Geometric series approximation
489            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
490        } else {
491            fanout
492        };
493
494        (input_cardinality * path_multiplier).max(1.0)
495    }
496
497    /// Estimates join cardinality.
498    fn estimate_join(&self, join: &JoinOp) -> f64 {
499        let left_card = self.estimate(&join.left);
500        let right_card = self.estimate(&join.right);
501
502        match join.join_type {
503            JoinType::Cross => left_card * right_card,
504            JoinType::Inner => {
505                // Assume join selectivity based on conditions
506                let selectivity = if join.conditions.is_empty() {
507                    1.0 // Cross join
508                } else {
509                    // Estimate based on number of conditions
510                    0.1_f64.powi(join.conditions.len() as i32)
511                };
512                (left_card * right_card * selectivity).max(1.0)
513            }
514            JoinType::Left => {
515                // Left join returns at least all left rows
516                let inner_card = self.estimate_join(&JoinOp {
517                    left: join.left.clone(),
518                    right: join.right.clone(),
519                    join_type: JoinType::Inner,
520                    conditions: join.conditions.clone(),
521                });
522                inner_card.max(left_card)
523            }
524            JoinType::Right => {
525                // Right join returns at least all right rows
526                let inner_card = self.estimate_join(&JoinOp {
527                    left: join.left.clone(),
528                    right: join.right.clone(),
529                    join_type: JoinType::Inner,
530                    conditions: join.conditions.clone(),
531                });
532                inner_card.max(right_card)
533            }
534            JoinType::Full => {
535                // Full join returns at least max(left, right)
536                let inner_card = self.estimate_join(&JoinOp {
537                    left: join.left.clone(),
538                    right: join.right.clone(),
539                    join_type: JoinType::Inner,
540                    conditions: join.conditions.clone(),
541                });
542                inner_card.max(left_card.max(right_card))
543            }
544            JoinType::Semi => {
545                // Semi join returns at most left cardinality
546                (left_card * self.default_selectivity).max(1.0)
547            }
548            JoinType::Anti => {
549                // Anti join returns at most left cardinality
550                (left_card * (1.0 - self.default_selectivity)).max(1.0)
551            }
552        }
553    }
554
555    /// Estimates aggregation cardinality.
556    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
557        let input_cardinality = self.estimate(&agg.input);
558
559        if agg.group_by.is_empty() {
560            // Global aggregation - single row
561            1.0
562        } else {
563            // Group by - estimate distinct groups
564            // Assume each group key reduces cardinality by 10
565            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
566            (input_cardinality / group_reduction).max(1.0)
567        }
568    }
569
570    /// Estimates sort cardinality (same as input).
571    fn estimate_sort(&self, sort: &SortOp) -> f64 {
572        self.estimate(&sort.input)
573    }
574
575    /// Estimates distinct cardinality.
576    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
577        let input_cardinality = self.estimate(&distinct.input);
578        // Assume 50% distinct by default
579        (input_cardinality * 0.5).max(1.0)
580    }
581
582    /// Estimates limit cardinality.
583    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
584        let input_cardinality = self.estimate(&limit.input);
585        (limit.count as f64).min(input_cardinality)
586    }
587
588    /// Estimates skip cardinality.
589    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
590        let input_cardinality = self.estimate(&skip.input);
591        (input_cardinality - skip.count as f64).max(0.0)
592    }
593
594    /// Estimates the selectivity of a predicate (0.0 to 1.0).
595    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
596        match expr {
597            LogicalExpression::Binary { left, op, right } => {
598                self.estimate_binary_selectivity(left, *op, right)
599            }
600            LogicalExpression::Unary { op, operand } => {
601                self.estimate_unary_selectivity(*op, operand)
602            }
603            LogicalExpression::Literal(value) => {
604                // Boolean literal
605                if let grafeo_common::types::Value::Bool(b) = value {
606                    if *b { 1.0 } else { 0.0 }
607                } else {
608                    self.default_selectivity
609                }
610            }
611            _ => self.default_selectivity,
612        }
613    }
614
615    /// Estimates binary expression selectivity.
616    fn estimate_binary_selectivity(
617        &self,
618        left: &LogicalExpression,
619        op: BinaryOp,
620        right: &LogicalExpression,
621    ) -> f64 {
622        match op {
623            // Equality - try histogram-based estimation
624            BinaryOp::Eq => {
625                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
626                    return selectivity;
627                }
628                0.01
629            }
630            // Inequality is very unselective
631            BinaryOp::Ne => 0.99,
632            // Range predicates - use histogram if available
633            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
634                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
635                    return selectivity;
636                }
637                0.33
638            }
639            // Logical operators - recursively estimate sub-expressions
640            BinaryOp::And => {
641                let left_sel = self.estimate_selectivity(left);
642                let right_sel = self.estimate_selectivity(right);
643                // AND reduces selectivity (multiply assuming independence)
644                left_sel * right_sel
645            }
646            BinaryOp::Or => {
647                let left_sel = self.estimate_selectivity(left);
648                let right_sel = self.estimate_selectivity(right);
649                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
650                // Assuming independence: P(A ∩ B) = P(A) * P(B)
651                (left_sel + right_sel - left_sel * right_sel).min(1.0)
652            }
653            // String operations
654            BinaryOp::StartsWith => 0.1,
655            BinaryOp::EndsWith => 0.1,
656            BinaryOp::Contains => 0.1,
657            BinaryOp::Like => 0.1,
658            // Collection membership
659            BinaryOp::In => 0.1,
660            // Other operations
661            _ => self.default_selectivity,
662        }
663    }
664
665    /// Tries to estimate equality selectivity using histograms.
666    fn try_equality_selectivity(
667        &self,
668        left: &LogicalExpression,
669        right: &LogicalExpression,
670    ) -> Option<f64> {
671        // Extract property access and literal value
672        let (label, column, value) = self.extract_column_and_value(left, right)?;
673
674        // Get column stats with histogram
675        let stats = self.get_column_stats(&label, &column)?;
676
677        // Try histogram-based estimation
678        if let Some(ref histogram) = stats.histogram {
679            return Some(histogram.equality_selectivity(value));
680        }
681
682        // Fall back to distinct count estimation
683        if stats.distinct_count > 0 {
684            return Some(1.0 / stats.distinct_count as f64);
685        }
686
687        None
688    }
689
690    /// Tries to estimate range selectivity using histograms.
691    fn try_range_selectivity(
692        &self,
693        left: &LogicalExpression,
694        op: BinaryOp,
695        right: &LogicalExpression,
696    ) -> Option<f64> {
697        // Extract property access and literal value
698        let (label, column, value) = self.extract_column_and_value(left, right)?;
699
700        // Get column stats
701        let stats = self.get_column_stats(&label, &column)?;
702
703        // Determine the range based on operator
704        let (lower, upper) = match op {
705            BinaryOp::Lt => (None, Some(value)),
706            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
707            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
708            BinaryOp::Ge => (Some(value), None),
709            _ => return None,
710        };
711
712        // Try histogram-based estimation first
713        if let Some(ref histogram) = stats.histogram {
714            return Some(histogram.range_selectivity(lower, upper));
715        }
716
717        // Fall back to min/max range estimation
718        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
719            let range = max - min;
720            if range <= 0.0 {
721                return Some(1.0);
722            }
723
724            let effective_lower = lower.unwrap_or(min).max(min);
725            let effective_upper = upper.unwrap_or(max).min(max);
726            let overlap = (effective_upper - effective_lower).max(0.0);
727            return Some((overlap / range).min(1.0).max(0.0));
728        }
729
730        None
731    }
732
733    /// Extracts column information and literal value from a comparison.
734    ///
735    /// Returns (label, column_name, numeric_value) if the expression is
736    /// a comparison between a property access and a numeric literal.
737    fn extract_column_and_value(
738        &self,
739        left: &LogicalExpression,
740        right: &LogicalExpression,
741    ) -> Option<(String, String, f64)> {
742        // Try left as property, right as literal
743        if let Some(result) = self.try_extract_property_literal(left, right) {
744            return Some(result);
745        }
746
747        // Try right as property, left as literal
748        self.try_extract_property_literal(right, left)
749    }
750
751    /// Tries to extract property and literal from a specific ordering.
752    fn try_extract_property_literal(
753        &self,
754        property_expr: &LogicalExpression,
755        literal_expr: &LogicalExpression,
756    ) -> Option<(String, String, f64)> {
757        // Extract property access
758        let (variable, property) = match property_expr {
759            LogicalExpression::Property { variable, property } => {
760                (variable.clone(), property.clone())
761            }
762            _ => return None,
763        };
764
765        // Extract numeric literal
766        let value = match literal_expr {
767            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
768            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
769            _ => return None,
770        };
771
772        // Try to find a label for this variable from table stats
773        // Use the variable name as a heuristic label lookup
774        // In practice, the optimizer would track which labels variables are bound to
775        for label in self.table_stats.keys() {
776            if let Some(stats) = self.table_stats.get(label) {
777                if stats.columns.contains_key(&property) {
778                    return Some((label.clone(), property, value));
779                }
780            }
781        }
782
783        // If no stats found but we have the property, return with variable as label
784        Some((variable, property, value))
785    }
786
787    /// Estimates unary expression selectivity.
788    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
789        match op {
790            UnaryOp::Not => 1.0 - self.default_selectivity,
791            UnaryOp::IsNull => 0.05, // Assume 5% nulls
792            UnaryOp::IsNotNull => 0.95,
793            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
794        }
795    }
796
797    /// Gets statistics for a column.
798    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
799        self.table_stats.get(label)?.columns.get(column)
800    }
801
802    /// Estimates equality selectivity using column statistics.
803    #[allow(dead_code)]
804    fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
805        if let Some(stats) = self.get_column_stats(label, column) {
806            if stats.distinct_count > 0 {
807                return 1.0 / stats.distinct_count as f64;
808            }
809        }
810        0.01 // Default for equality
811    }
812
813    /// Estimates range selectivity using column statistics.
814    #[allow(dead_code)]
815    fn estimate_range_with_stats(
816        &self,
817        label: &str,
818        column: &str,
819        lower: Option<f64>,
820        upper: Option<f64>,
821    ) -> f64 {
822        if let Some(stats) = self.get_column_stats(label, column) {
823            if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
824                let range = max - min;
825                if range <= 0.0 {
826                    return 1.0;
827                }
828
829                let effective_lower = lower.unwrap_or(min).max(min);
830                let effective_upper = upper.unwrap_or(max).min(max);
831
832                let overlap = (effective_upper - effective_lower).max(0.0);
833                return (overlap / range).min(1.0).max(0.0);
834            }
835        }
836        0.33 // Default for range
837    }
838}
839
840impl Default for CardinalityEstimator {
841    fn default() -> Self {
842        Self::new()
843    }
844}
845
846#[cfg(test)]
847mod tests {
848    use super::*;
849    use crate::query::plan::{
850        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
851        Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
852    };
853    use grafeo_common::types::Value;
854
855    #[test]
856    fn test_node_scan_with_stats() {
857        let mut estimator = CardinalityEstimator::new();
858        estimator.add_table_stats("Person", TableStats::new(5000));
859
860        let scan = LogicalOperator::NodeScan(NodeScanOp {
861            variable: "n".to_string(),
862            label: Some("Person".to_string()),
863            input: None,
864        });
865
866        let cardinality = estimator.estimate(&scan);
867        assert!((cardinality - 5000.0).abs() < 0.001);
868    }
869
870    #[test]
871    fn test_filter_reduces_cardinality() {
872        let mut estimator = CardinalityEstimator::new();
873        estimator.add_table_stats("Person", TableStats::new(1000));
874
875        let filter = LogicalOperator::Filter(FilterOp {
876            predicate: LogicalExpression::Binary {
877                left: Box::new(LogicalExpression::Property {
878                    variable: "n".to_string(),
879                    property: "age".to_string(),
880                }),
881                op: BinaryOp::Eq,
882                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
883            },
884            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
885                variable: "n".to_string(),
886                label: Some("Person".to_string()),
887                input: None,
888            })),
889        });
890
891        let cardinality = estimator.estimate(&filter);
892        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
893        assert!(cardinality < 1000.0);
894        assert!(cardinality >= 1.0);
895    }
896
897    #[test]
898    fn test_join_cardinality() {
899        let mut estimator = CardinalityEstimator::new();
900        estimator.add_table_stats("Person", TableStats::new(1000));
901        estimator.add_table_stats("Company", TableStats::new(100));
902
903        let join = LogicalOperator::Join(JoinOp {
904            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
905                variable: "p".to_string(),
906                label: Some("Person".to_string()),
907                input: None,
908            })),
909            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
910                variable: "c".to_string(),
911                label: Some("Company".to_string()),
912                input: None,
913            })),
914            join_type: JoinType::Inner,
915            conditions: vec![JoinCondition {
916                left: LogicalExpression::Property {
917                    variable: "p".to_string(),
918                    property: "company_id".to_string(),
919                },
920                right: LogicalExpression::Property {
921                    variable: "c".to_string(),
922                    property: "id".to_string(),
923                },
924            }],
925        });
926
927        let cardinality = estimator.estimate(&join);
928        // Should be less than cross product
929        assert!(cardinality < 1000.0 * 100.0);
930    }
931
932    #[test]
933    fn test_limit_caps_cardinality() {
934        let mut estimator = CardinalityEstimator::new();
935        estimator.add_table_stats("Person", TableStats::new(1000));
936
937        let limit = LogicalOperator::Limit(LimitOp {
938            count: 10,
939            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
940                variable: "n".to_string(),
941                label: Some("Person".to_string()),
942                input: None,
943            })),
944        });
945
946        let cardinality = estimator.estimate(&limit);
947        assert!((cardinality - 10.0).abs() < 0.001);
948    }
949
950    #[test]
951    fn test_aggregate_reduces_cardinality() {
952        let mut estimator = CardinalityEstimator::new();
953        estimator.add_table_stats("Person", TableStats::new(1000));
954
955        // Global aggregation
956        let global_agg = LogicalOperator::Aggregate(AggregateOp {
957            group_by: vec![],
958            aggregates: vec![],
959            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
960                variable: "n".to_string(),
961                label: Some("Person".to_string()),
962                input: None,
963            })),
964            having: None,
965        });
966
967        let cardinality = estimator.estimate(&global_agg);
968        assert!((cardinality - 1.0).abs() < 0.001);
969
970        // Group by aggregation
971        let group_agg = LogicalOperator::Aggregate(AggregateOp {
972            group_by: vec![LogicalExpression::Property {
973                variable: "n".to_string(),
974                property: "city".to_string(),
975            }],
976            aggregates: vec![],
977            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
978                variable: "n".to_string(),
979                label: Some("Person".to_string()),
980                input: None,
981            })),
982            having: None,
983        });
984
985        let cardinality = estimator.estimate(&group_agg);
986        // Should be less than input
987        assert!(cardinality < 1000.0);
988    }
989
990    #[test]
991    fn test_node_scan_without_stats() {
992        let estimator = CardinalityEstimator::new();
993
994        let scan = LogicalOperator::NodeScan(NodeScanOp {
995            variable: "n".to_string(),
996            label: Some("Unknown".to_string()),
997            input: None,
998        });
999
1000        let cardinality = estimator.estimate(&scan);
1001        // Should return default (1000)
1002        assert!((cardinality - 1000.0).abs() < 0.001);
1003    }
1004
1005    #[test]
1006    fn test_node_scan_no_label() {
1007        let estimator = CardinalityEstimator::new();
1008
1009        let scan = LogicalOperator::NodeScan(NodeScanOp {
1010            variable: "n".to_string(),
1011            label: None,
1012            input: None,
1013        });
1014
1015        let cardinality = estimator.estimate(&scan);
1016        // Should scan all nodes (default)
1017        assert!((cardinality - 1000.0).abs() < 0.001);
1018    }
1019
1020    #[test]
1021    fn test_filter_inequality_selectivity() {
1022        let mut estimator = CardinalityEstimator::new();
1023        estimator.add_table_stats("Person", TableStats::new(1000));
1024
1025        let filter = LogicalOperator::Filter(FilterOp {
1026            predicate: LogicalExpression::Binary {
1027                left: Box::new(LogicalExpression::Property {
1028                    variable: "n".to_string(),
1029                    property: "age".to_string(),
1030                }),
1031                op: BinaryOp::Ne,
1032                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1033            },
1034            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1035                variable: "n".to_string(),
1036                label: Some("Person".to_string()),
1037                input: None,
1038            })),
1039        });
1040
1041        let cardinality = estimator.estimate(&filter);
1042        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1043        assert!(cardinality > 900.0);
1044    }
1045
1046    #[test]
1047    fn test_filter_range_selectivity() {
1048        let mut estimator = CardinalityEstimator::new();
1049        estimator.add_table_stats("Person", TableStats::new(1000));
1050
1051        let filter = LogicalOperator::Filter(FilterOp {
1052            predicate: LogicalExpression::Binary {
1053                left: Box::new(LogicalExpression::Property {
1054                    variable: "n".to_string(),
1055                    property: "age".to_string(),
1056                }),
1057                op: BinaryOp::Gt,
1058                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1059            },
1060            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1061                variable: "n".to_string(),
1062                label: Some("Person".to_string()),
1063                input: None,
1064            })),
1065        });
1066
1067        let cardinality = estimator.estimate(&filter);
1068        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1069        assert!(cardinality < 500.0);
1070        assert!(cardinality > 100.0);
1071    }
1072
1073    #[test]
1074    fn test_filter_and_selectivity() {
1075        let mut estimator = CardinalityEstimator::new();
1076        estimator.add_table_stats("Person", TableStats::new(1000));
1077
1078        // Test AND with two equality predicates
1079        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1080        let filter = LogicalOperator::Filter(FilterOp {
1081            predicate: LogicalExpression::Binary {
1082                left: Box::new(LogicalExpression::Binary {
1083                    left: Box::new(LogicalExpression::Property {
1084                        variable: "n".to_string(),
1085                        property: "city".to_string(),
1086                    }),
1087                    op: BinaryOp::Eq,
1088                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1089                }),
1090                op: BinaryOp::And,
1091                right: Box::new(LogicalExpression::Binary {
1092                    left: Box::new(LogicalExpression::Property {
1093                        variable: "n".to_string(),
1094                        property: "age".to_string(),
1095                    }),
1096                    op: BinaryOp::Eq,
1097                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1098                }),
1099            },
1100            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1101                variable: "n".to_string(),
1102                label: Some("Person".to_string()),
1103                input: None,
1104            })),
1105        });
1106
1107        let cardinality = estimator.estimate(&filter);
1108        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1109        // 1000 * 0.0001 = 0.1, min is 1.0
1110        assert!(cardinality < 100.0);
1111        assert!(cardinality >= 1.0);
1112    }
1113
1114    #[test]
1115    fn test_filter_or_selectivity() {
1116        let mut estimator = CardinalityEstimator::new();
1117        estimator.add_table_stats("Person", TableStats::new(1000));
1118
1119        // Test OR with two equality predicates
1120        // Each equality has selectivity 0.01
1121        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1122        let filter = LogicalOperator::Filter(FilterOp {
1123            predicate: LogicalExpression::Binary {
1124                left: Box::new(LogicalExpression::Binary {
1125                    left: Box::new(LogicalExpression::Property {
1126                        variable: "n".to_string(),
1127                        property: "city".to_string(),
1128                    }),
1129                    op: BinaryOp::Eq,
1130                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1131                }),
1132                op: BinaryOp::Or,
1133                right: Box::new(LogicalExpression::Binary {
1134                    left: Box::new(LogicalExpression::Property {
1135                        variable: "n".to_string(),
1136                        property: "city".to_string(),
1137                    }),
1138                    op: BinaryOp::Eq,
1139                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1140                }),
1141            },
1142            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1143                variable: "n".to_string(),
1144                label: Some("Person".to_string()),
1145                input: None,
1146            })),
1147        });
1148
1149        let cardinality = estimator.estimate(&filter);
1150        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1151        assert!(cardinality < 100.0);
1152        assert!(cardinality >= 1.0);
1153    }
1154
1155    #[test]
1156    fn test_filter_literal_true() {
1157        let mut estimator = CardinalityEstimator::new();
1158        estimator.add_table_stats("Person", TableStats::new(1000));
1159
1160        let filter = LogicalOperator::Filter(FilterOp {
1161            predicate: LogicalExpression::Literal(Value::Bool(true)),
1162            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1163                variable: "n".to_string(),
1164                label: Some("Person".to_string()),
1165                input: None,
1166            })),
1167        });
1168
1169        let cardinality = estimator.estimate(&filter);
1170        // Literal true has selectivity 1.0
1171        assert!((cardinality - 1000.0).abs() < 0.001);
1172    }
1173
1174    #[test]
1175    fn test_filter_literal_false() {
1176        let mut estimator = CardinalityEstimator::new();
1177        estimator.add_table_stats("Person", TableStats::new(1000));
1178
1179        let filter = LogicalOperator::Filter(FilterOp {
1180            predicate: LogicalExpression::Literal(Value::Bool(false)),
1181            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1182                variable: "n".to_string(),
1183                label: Some("Person".to_string()),
1184                input: None,
1185            })),
1186        });
1187
1188        let cardinality = estimator.estimate(&filter);
1189        // Literal false has selectivity 0.0, but min is 1.0
1190        assert!((cardinality - 1.0).abs() < 0.001);
1191    }
1192
1193    #[test]
1194    fn test_unary_not_selectivity() {
1195        let mut estimator = CardinalityEstimator::new();
1196        estimator.add_table_stats("Person", TableStats::new(1000));
1197
1198        let filter = LogicalOperator::Filter(FilterOp {
1199            predicate: LogicalExpression::Unary {
1200                op: UnaryOp::Not,
1201                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1202            },
1203            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1204                variable: "n".to_string(),
1205                label: Some("Person".to_string()),
1206                input: None,
1207            })),
1208        });
1209
1210        let cardinality = estimator.estimate(&filter);
1211        // NOT inverts selectivity
1212        assert!(cardinality < 1000.0);
1213    }
1214
1215    #[test]
1216    fn test_unary_is_null_selectivity() {
1217        let mut estimator = CardinalityEstimator::new();
1218        estimator.add_table_stats("Person", TableStats::new(1000));
1219
1220        let filter = LogicalOperator::Filter(FilterOp {
1221            predicate: LogicalExpression::Unary {
1222                op: UnaryOp::IsNull,
1223                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1224            },
1225            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1226                variable: "n".to_string(),
1227                label: Some("Person".to_string()),
1228                input: None,
1229            })),
1230        });
1231
1232        let cardinality = estimator.estimate(&filter);
1233        // IS NULL has selectivity 0.05
1234        assert!(cardinality < 100.0);
1235    }
1236
1237    #[test]
1238    fn test_expand_cardinality() {
1239        let mut estimator = CardinalityEstimator::new();
1240        estimator.add_table_stats("Person", TableStats::new(100));
1241
1242        let expand = LogicalOperator::Expand(ExpandOp {
1243            from_variable: "a".to_string(),
1244            to_variable: "b".to_string(),
1245            edge_variable: None,
1246            direction: ExpandDirection::Outgoing,
1247            edge_type: None,
1248            min_hops: 1,
1249            max_hops: Some(1),
1250            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1251                variable: "a".to_string(),
1252                label: Some("Person".to_string()),
1253                input: None,
1254            })),
1255            path_alias: None,
1256        });
1257
1258        let cardinality = estimator.estimate(&expand);
1259        // Expand multiplies by fanout (10)
1260        assert!(cardinality > 100.0);
1261    }
1262
1263    #[test]
1264    fn test_expand_with_edge_type_filter() {
1265        let mut estimator = CardinalityEstimator::new();
1266        estimator.add_table_stats("Person", TableStats::new(100));
1267
1268        let expand = LogicalOperator::Expand(ExpandOp {
1269            from_variable: "a".to_string(),
1270            to_variable: "b".to_string(),
1271            edge_variable: None,
1272            direction: ExpandDirection::Outgoing,
1273            edge_type: Some("KNOWS".to_string()),
1274            min_hops: 1,
1275            max_hops: Some(1),
1276            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1277                variable: "a".to_string(),
1278                label: Some("Person".to_string()),
1279                input: None,
1280            })),
1281            path_alias: None,
1282        });
1283
1284        let cardinality = estimator.estimate(&expand);
1285        // With edge type, fanout is reduced by half
1286        assert!(cardinality > 100.0);
1287    }
1288
1289    #[test]
1290    fn test_expand_variable_length() {
1291        let mut estimator = CardinalityEstimator::new();
1292        estimator.add_table_stats("Person", TableStats::new(100));
1293
1294        let expand = LogicalOperator::Expand(ExpandOp {
1295            from_variable: "a".to_string(),
1296            to_variable: "b".to_string(),
1297            edge_variable: None,
1298            direction: ExpandDirection::Outgoing,
1299            edge_type: None,
1300            min_hops: 1,
1301            max_hops: Some(3),
1302            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1303                variable: "a".to_string(),
1304                label: Some("Person".to_string()),
1305                input: None,
1306            })),
1307            path_alias: None,
1308        });
1309
1310        let cardinality = estimator.estimate(&expand);
1311        // Variable length path has much higher cardinality
1312        assert!(cardinality > 500.0);
1313    }
1314
1315    #[test]
1316    fn test_join_cross_product() {
1317        let mut estimator = CardinalityEstimator::new();
1318        estimator.add_table_stats("Person", TableStats::new(100));
1319        estimator.add_table_stats("Company", TableStats::new(50));
1320
1321        let join = LogicalOperator::Join(JoinOp {
1322            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1323                variable: "p".to_string(),
1324                label: Some("Person".to_string()),
1325                input: None,
1326            })),
1327            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1328                variable: "c".to_string(),
1329                label: Some("Company".to_string()),
1330                input: None,
1331            })),
1332            join_type: JoinType::Cross,
1333            conditions: vec![],
1334        });
1335
1336        let cardinality = estimator.estimate(&join);
1337        // Cross join = 100 * 50 = 5000
1338        assert!((cardinality - 5000.0).abs() < 0.001);
1339    }
1340
1341    #[test]
1342    fn test_join_left_outer() {
1343        let mut estimator = CardinalityEstimator::new();
1344        estimator.add_table_stats("Person", TableStats::new(1000));
1345        estimator.add_table_stats("Company", TableStats::new(10));
1346
1347        let join = LogicalOperator::Join(JoinOp {
1348            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1349                variable: "p".to_string(),
1350                label: Some("Person".to_string()),
1351                input: None,
1352            })),
1353            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1354                variable: "c".to_string(),
1355                label: Some("Company".to_string()),
1356                input: None,
1357            })),
1358            join_type: JoinType::Left,
1359            conditions: vec![JoinCondition {
1360                left: LogicalExpression::Variable("p".to_string()),
1361                right: LogicalExpression::Variable("c".to_string()),
1362            }],
1363        });
1364
1365        let cardinality = estimator.estimate(&join);
1366        // Left join returns at least all left rows
1367        assert!(cardinality >= 1000.0);
1368    }
1369
1370    #[test]
1371    fn test_join_semi() {
1372        let mut estimator = CardinalityEstimator::new();
1373        estimator.add_table_stats("Person", TableStats::new(1000));
1374        estimator.add_table_stats("Company", TableStats::new(100));
1375
1376        let join = LogicalOperator::Join(JoinOp {
1377            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1378                variable: "p".to_string(),
1379                label: Some("Person".to_string()),
1380                input: None,
1381            })),
1382            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1383                variable: "c".to_string(),
1384                label: Some("Company".to_string()),
1385                input: None,
1386            })),
1387            join_type: JoinType::Semi,
1388            conditions: vec![],
1389        });
1390
1391        let cardinality = estimator.estimate(&join);
1392        // Semi join returns at most left cardinality
1393        assert!(cardinality <= 1000.0);
1394    }
1395
1396    #[test]
1397    fn test_join_anti() {
1398        let mut estimator = CardinalityEstimator::new();
1399        estimator.add_table_stats("Person", TableStats::new(1000));
1400        estimator.add_table_stats("Company", TableStats::new(100));
1401
1402        let join = LogicalOperator::Join(JoinOp {
1403            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1404                variable: "p".to_string(),
1405                label: Some("Person".to_string()),
1406                input: None,
1407            })),
1408            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1409                variable: "c".to_string(),
1410                label: Some("Company".to_string()),
1411                input: None,
1412            })),
1413            join_type: JoinType::Anti,
1414            conditions: vec![],
1415        });
1416
1417        let cardinality = estimator.estimate(&join);
1418        // Anti join returns at most left cardinality
1419        assert!(cardinality <= 1000.0);
1420        assert!(cardinality >= 1.0);
1421    }
1422
1423    #[test]
1424    fn test_project_preserves_cardinality() {
1425        let mut estimator = CardinalityEstimator::new();
1426        estimator.add_table_stats("Person", TableStats::new(1000));
1427
1428        let project = LogicalOperator::Project(ProjectOp {
1429            projections: vec![Projection {
1430                expression: LogicalExpression::Variable("n".to_string()),
1431                alias: None,
1432            }],
1433            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1434                variable: "n".to_string(),
1435                label: Some("Person".to_string()),
1436                input: None,
1437            })),
1438        });
1439
1440        let cardinality = estimator.estimate(&project);
1441        assert!((cardinality - 1000.0).abs() < 0.001);
1442    }
1443
1444    #[test]
1445    fn test_sort_preserves_cardinality() {
1446        let mut estimator = CardinalityEstimator::new();
1447        estimator.add_table_stats("Person", TableStats::new(1000));
1448
1449        let sort = LogicalOperator::Sort(SortOp {
1450            keys: vec![SortKey {
1451                expression: LogicalExpression::Variable("n".to_string()),
1452                order: SortOrder::Ascending,
1453            }],
1454            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1455                variable: "n".to_string(),
1456                label: Some("Person".to_string()),
1457                input: None,
1458            })),
1459        });
1460
1461        let cardinality = estimator.estimate(&sort);
1462        assert!((cardinality - 1000.0).abs() < 0.001);
1463    }
1464
1465    #[test]
1466    fn test_distinct_reduces_cardinality() {
1467        let mut estimator = CardinalityEstimator::new();
1468        estimator.add_table_stats("Person", TableStats::new(1000));
1469
1470        let distinct = LogicalOperator::Distinct(DistinctOp {
1471            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1472                variable: "n".to_string(),
1473                label: Some("Person".to_string()),
1474                input: None,
1475            })),
1476            columns: None,
1477        });
1478
1479        let cardinality = estimator.estimate(&distinct);
1480        // Distinct assumes 50% unique
1481        assert!((cardinality - 500.0).abs() < 0.001);
1482    }
1483
1484    #[test]
1485    fn test_skip_reduces_cardinality() {
1486        let mut estimator = CardinalityEstimator::new();
1487        estimator.add_table_stats("Person", TableStats::new(1000));
1488
1489        let skip = LogicalOperator::Skip(SkipOp {
1490            count: 100,
1491            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1492                variable: "n".to_string(),
1493                label: Some("Person".to_string()),
1494                input: None,
1495            })),
1496        });
1497
1498        let cardinality = estimator.estimate(&skip);
1499        assert!((cardinality - 900.0).abs() < 0.001);
1500    }
1501
1502    #[test]
1503    fn test_return_preserves_cardinality() {
1504        let mut estimator = CardinalityEstimator::new();
1505        estimator.add_table_stats("Person", TableStats::new(1000));
1506
1507        let ret = LogicalOperator::Return(ReturnOp {
1508            items: vec![ReturnItem {
1509                expression: LogicalExpression::Variable("n".to_string()),
1510                alias: None,
1511            }],
1512            distinct: false,
1513            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1514                variable: "n".to_string(),
1515                label: Some("Person".to_string()),
1516                input: None,
1517            })),
1518        });
1519
1520        let cardinality = estimator.estimate(&ret);
1521        assert!((cardinality - 1000.0).abs() < 0.001);
1522    }
1523
1524    #[test]
1525    fn test_empty_cardinality() {
1526        let estimator = CardinalityEstimator::new();
1527        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1528        assert!((cardinality).abs() < 0.001);
1529    }
1530
1531    #[test]
1532    fn test_table_stats_with_column() {
1533        let stats = TableStats::new(1000).with_column(
1534            "age",
1535            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1536        );
1537
1538        assert_eq!(stats.row_count, 1000);
1539        let col = stats.columns.get("age").unwrap();
1540        assert_eq!(col.distinct_count, 50);
1541        assert_eq!(col.null_count, 10);
1542        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1543        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1544    }
1545
1546    #[test]
1547    fn test_estimator_default() {
1548        let estimator = CardinalityEstimator::default();
1549        let scan = LogicalOperator::NodeScan(NodeScanOp {
1550            variable: "n".to_string(),
1551            label: None,
1552            input: None,
1553        });
1554        let cardinality = estimator.estimate(&scan);
1555        assert!((cardinality - 1000.0).abs() < 0.001);
1556    }
1557
1558    #[test]
1559    fn test_set_avg_fanout() {
1560        let mut estimator = CardinalityEstimator::new();
1561        estimator.add_table_stats("Person", TableStats::new(100));
1562        estimator.set_avg_fanout(5.0);
1563
1564        let expand = LogicalOperator::Expand(ExpandOp {
1565            from_variable: "a".to_string(),
1566            to_variable: "b".to_string(),
1567            edge_variable: None,
1568            direction: ExpandDirection::Outgoing,
1569            edge_type: None,
1570            min_hops: 1,
1571            max_hops: Some(1),
1572            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1573                variable: "a".to_string(),
1574                label: Some("Person".to_string()),
1575                input: None,
1576            })),
1577            path_alias: None,
1578        });
1579
1580        let cardinality = estimator.estimate(&expand);
1581        // With fanout 5: 100 * 5 = 500
1582        assert!((cardinality - 500.0).abs() < 0.001);
1583    }
1584
1585    #[test]
1586    fn test_multiple_group_by_keys_reduce_cardinality() {
1587        // The current implementation uses a simplified model where more group by keys
1588        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1589        // that works for most cases where group by keys are correlated.
1590        let mut estimator = CardinalityEstimator::new();
1591        estimator.add_table_stats("Person", TableStats::new(10000));
1592
1593        let single_group = LogicalOperator::Aggregate(AggregateOp {
1594            group_by: vec![LogicalExpression::Property {
1595                variable: "n".to_string(),
1596                property: "city".to_string(),
1597            }],
1598            aggregates: vec![],
1599            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1600                variable: "n".to_string(),
1601                label: Some("Person".to_string()),
1602                input: None,
1603            })),
1604            having: None,
1605        });
1606
1607        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1608            group_by: vec![
1609                LogicalExpression::Property {
1610                    variable: "n".to_string(),
1611                    property: "city".to_string(),
1612                },
1613                LogicalExpression::Property {
1614                    variable: "n".to_string(),
1615                    property: "country".to_string(),
1616                },
1617            ],
1618            aggregates: vec![],
1619            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1620                variable: "n".to_string(),
1621                label: Some("Person".to_string()),
1622                input: None,
1623            })),
1624            having: None,
1625        });
1626
1627        let single_card = estimator.estimate(&single_group);
1628        let multi_card = estimator.estimate(&multi_group);
1629
1630        // Both should reduce cardinality from input
1631        assert!(single_card < 10000.0);
1632        assert!(multi_card < 10000.0);
1633        // Both should be at least 1
1634        assert!(single_card >= 1.0);
1635        assert!(multi_card >= 1.0);
1636    }
1637
1638    // ============= Histogram Tests =============
1639
1640    #[test]
1641    fn test_histogram_build_uniform() {
1642        // Build histogram from uniformly distributed data
1643        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1644        let histogram = EquiDepthHistogram::build(&values, 10);
1645
1646        assert_eq!(histogram.num_buckets(), 10);
1647        assert_eq!(histogram.total_rows(), 100);
1648
1649        // Each bucket should have approximately 10 rows
1650        for bucket in histogram.buckets() {
1651            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1652        }
1653    }
1654
1655    #[test]
1656    fn test_histogram_build_skewed() {
1657        // Build histogram from skewed data (many small values, few large)
1658        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1659        values.extend((0..20).map(|i| 1000.0 + i as f64));
1660        let histogram = EquiDepthHistogram::build(&values, 5);
1661
1662        assert_eq!(histogram.num_buckets(), 5);
1663        assert_eq!(histogram.total_rows(), 100);
1664
1665        // Each bucket should have ~20 rows despite skewed data
1666        for bucket in histogram.buckets() {
1667            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1668        }
1669    }
1670
1671    #[test]
1672    fn test_histogram_range_selectivity_full() {
1673        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1674        let histogram = EquiDepthHistogram::build(&values, 10);
1675
1676        // Full range should have selectivity ~1.0
1677        let selectivity = histogram.range_selectivity(None, None);
1678        assert!((selectivity - 1.0).abs() < 0.01);
1679    }
1680
1681    #[test]
1682    fn test_histogram_range_selectivity_half() {
1683        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1684        let histogram = EquiDepthHistogram::build(&values, 10);
1685
1686        // Values >= 50 should be ~50% (half the data)
1687        let selectivity = histogram.range_selectivity(Some(50.0), None);
1688        assert!(selectivity > 0.4 && selectivity < 0.6);
1689    }
1690
1691    #[test]
1692    fn test_histogram_range_selectivity_quarter() {
1693        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1694        let histogram = EquiDepthHistogram::build(&values, 10);
1695
1696        // Values < 25 should be ~25%
1697        let selectivity = histogram.range_selectivity(None, Some(25.0));
1698        assert!(selectivity > 0.2 && selectivity < 0.3);
1699    }
1700
1701    #[test]
1702    fn test_histogram_equality_selectivity() {
1703        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1704        let histogram = EquiDepthHistogram::build(&values, 10);
1705
1706        // Equality on 100 distinct values should be ~1%
1707        let selectivity = histogram.equality_selectivity(50.0);
1708        assert!(selectivity > 0.005 && selectivity < 0.02);
1709    }
1710
1711    #[test]
1712    fn test_histogram_empty() {
1713        let histogram = EquiDepthHistogram::build(&[], 10);
1714
1715        assert_eq!(histogram.num_buckets(), 0);
1716        assert_eq!(histogram.total_rows(), 0);
1717
1718        // Default selectivity for empty histogram
1719        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1720        assert!((selectivity - 0.33).abs() < 0.01);
1721    }
1722
1723    #[test]
1724    fn test_histogram_bucket_overlap() {
1725        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1726
1727        // Full overlap
1728        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1729
1730        // Half overlap (lower half)
1731        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1732
1733        // Half overlap (upper half)
1734        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1735
1736        // No overlap (below)
1737        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1738
1739        // No overlap (above)
1740        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1741    }
1742
1743    #[test]
1744    fn test_column_stats_from_values() {
1745        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
1746        let stats = ColumnStats::from_values(values, 4);
1747
1748        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
1749        assert!(stats.min_value.is_some());
1750        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
1751        assert!(stats.max_value.is_some());
1752        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
1753        assert!(stats.histogram.is_some());
1754    }
1755
1756    #[test]
1757    fn test_column_stats_with_histogram_builder() {
1758        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1759        let histogram = EquiDepthHistogram::build(&values, 10);
1760
1761        let stats = ColumnStats::new(100)
1762            .with_range(0.0, 99.0)
1763            .with_histogram(histogram);
1764
1765        assert!(stats.histogram.is_some());
1766        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
1767    }
1768
1769    #[test]
1770    fn test_filter_with_histogram_stats() {
1771        let mut estimator = CardinalityEstimator::new();
1772
1773        // Create stats with histogram for age column
1774        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
1775        let histogram = EquiDepthHistogram::build(&age_values, 10);
1776        let age_stats = ColumnStats::new(62)
1777            .with_range(18.0, 79.0)
1778            .with_histogram(histogram);
1779
1780        estimator.add_table_stats(
1781            "Person",
1782            TableStats::new(1000).with_column("age", age_stats),
1783        );
1784
1785        // Filter: age > 50
1786        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
1787        let filter = LogicalOperator::Filter(FilterOp {
1788            predicate: LogicalExpression::Binary {
1789                left: Box::new(LogicalExpression::Property {
1790                    variable: "n".to_string(),
1791                    property: "age".to_string(),
1792                }),
1793                op: BinaryOp::Gt,
1794                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1795            },
1796            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1797                variable: "n".to_string(),
1798                label: Some("Person".to_string()),
1799                input: None,
1800            })),
1801        });
1802
1803        let cardinality = estimator.estimate(&filter);
1804
1805        // With histogram, should get more accurate estimate than default 0.33
1806        // Expected: ~47.5% of 1000 = ~475
1807        assert!(cardinality > 300.0 && cardinality < 600.0);
1808    }
1809
1810    #[test]
1811    fn test_filter_equality_with_histogram() {
1812        let mut estimator = CardinalityEstimator::new();
1813
1814        // Create stats with histogram
1815        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1816        let histogram = EquiDepthHistogram::build(&values, 10);
1817        let stats = ColumnStats::new(100)
1818            .with_range(0.0, 99.0)
1819            .with_histogram(histogram);
1820
1821        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
1822
1823        // Filter: value = 50
1824        let filter = LogicalOperator::Filter(FilterOp {
1825            predicate: LogicalExpression::Binary {
1826                left: Box::new(LogicalExpression::Property {
1827                    variable: "d".to_string(),
1828                    property: "value".to_string(),
1829                }),
1830                op: BinaryOp::Eq,
1831                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1832            },
1833            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1834                variable: "d".to_string(),
1835                label: Some("Data".to_string()),
1836                input: None,
1837            })),
1838        });
1839
1840        let cardinality = estimator.estimate(&filter);
1841
1842        // With 100 distinct values, selectivity should be ~1%
1843        // 1000 * 0.01 = 10
1844        assert!(cardinality >= 1.0 && cardinality < 50.0);
1845    }
1846
1847    #[test]
1848    fn test_histogram_min_max() {
1849        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
1850        let histogram = EquiDepthHistogram::build(&values, 2);
1851
1852        assert_eq!(histogram.min_value(), Some(5.0));
1853        // Max is the upper bound of the last bucket
1854        assert!(histogram.max_value().is_some());
1855    }
1856}