Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20    VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// A bucket in an equi-depth histogram.
///
/// Each bucket covers the half-open value range `[lower_bound, upper_bound)`
/// and records how many rows, and how many distinct values, fall inside it.
/// In an equi-depth histogram all buckets contain approximately the same
/// number of rows.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Lower bound of the bucket (inclusive).
    pub lower_bound: f64,
    /// Upper bound of the bucket (exclusive for every bucket, see
    /// [`HistogramBucket::contains`]).
    ///
    /// [`EquiDepthHistogram::build`] sets the last bucket's upper bound to the
    /// largest sampled value plus `1.0` so that this value is still covered.
    pub upper_bound: f64,
    /// Number of rows in this bucket.
    pub frequency: u64,
    /// Number of distinct values in this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a new histogram bucket.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns the width of this bucket (`upper_bound - lower_bound`).
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Checks whether `value` lies in the half-open range `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        value >= self.lower_bound && value < self.upper_bound
    }

    /// Returns the fraction (`0.0..=1.0`) of this bucket's width covered by the
    /// range `[lower, upper]`, where `None` means unbounded on that side.
    ///
    /// A zero-width bucket (a single point) counts as fully covered when the
    /// range contains that point, and as not covered otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clamp the query range to the bucket's own bounds.
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            return if effective_lower <= self.lower_bound && effective_upper >= self.upper_bound
            {
                1.0
            } else {
                0.0
            };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}

/// An equi-depth histogram for selectivity estimation.
///
/// Equi-depth histograms partition data into buckets where each bucket
/// contains approximately the same number of rows. This provides more
/// accurate selectivity estimates than assuming uniform distribution,
/// especially for skewed data.
///
/// # Example
///
/// ```no_run
/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
///
/// // Build a histogram from sorted values
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
/// let histogram = EquiDepthHistogram::build(&values, 4);
///
/// // Estimate selectivity for age > 25
/// let selectivity = histogram.range_selectivity(Some(25.0), None);
/// ```
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// The histogram buckets, sorted by lower_bound.
    buckets: Vec<HistogramBucket>,
    /// Total number of rows represented by this histogram.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Creates a new histogram from pre-built buckets.
    ///
    /// The total row count is the sum of the bucket frequencies.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds an equi-depth histogram from a sorted slice of values.
    ///
    /// Buckets target `ceil(len / num_buckets)` rows each. A run of equal
    /// values is never split across two buckets, so a skewed column may
    /// produce fewer (and larger) buckets than requested. In exchange, every
    /// bucket has a strictly positive width and contains all of its own rows.
    ///
    /// # Arguments
    /// * `values` - Numeric values sorted in ascending order (unsorted input
    ///   produces meaningless buckets)
    /// * `num_buckets` - The desired (maximum) number of buckets
    ///
    /// # Returns
    /// An equi-depth histogram with approximately equal row counts per bucket.
    /// The histogram is empty when `values` is empty or `num_buckets` is 0.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self {
                buckets: Vec::new(),
                total_rows: 0,
            };
        }

        let num_buckets = num_buckets.min(values.len());
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        let mut start_idx = 0;
        while start_idx < values.len() {
            let mut end_idx = (start_idx + rows_per_bucket).min(values.len());
            // Never split a run of equal values across two buckets. Otherwise a
            // bucket could end with `upper_bound == lower_bound` (zero width),
            // `contains` would reject all of its rows, and the rows of a frequent
            // value would be attributed to the following bucket.
            while end_idx < values.len() && values[end_idx] == values[end_idx - 1] {
                end_idx += 1;
            }

            let lower_bound = values[start_idx];
            let upper_bound = match values.get(end_idx) {
                // For sorted, NaN-free input the next value is strictly greater
                // than every value in this bucket (see the loop above).
                Some(&next) => next,
                // Last bucket: extend past the maximum so `contains(max)` holds.
                None => values[end_idx - 1] + 1.0,
            };

            let bucket_values = &values[start_idx..end_idx];
            buckets.push(HistogramBucket::new(
                lower_bound,
                upper_bound,
                bucket_values.len() as u64,
                count_distinct(bucket_values),
            ));

            start_idx = end_idx;
        }

        Self::new(buckets)
    }

    /// Returns the number of buckets in this histogram.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Returns the total number of rows represented.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Returns the histogram buckets.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimates selectivity for a range predicate.
    ///
    /// Rows are assumed to be spread uniformly within each bucket, so a
    /// partially covered bucket contributes proportionally to the covered width.
    ///
    /// # Arguments
    /// * `lower` - Lower bound (None for unbounded)
    /// * `upper` - Upper bound (None for unbounded)
    ///
    /// # Returns
    /// Estimated fraction of rows matching the range (0.0 to 1.0), or `0.33`
    /// when the histogram holds no rows.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33; // Default fallback
        }

        let matching_rows: f64 = self
            .buckets
            .iter()
            // Skip buckets lying entirely outside the range.
            .filter(|b| !lower.is_some_and(|l| b.upper_bound <= l))
            .filter(|b| !upper.is_some_and(|u| b.lower_bound >= u))
            .map(|b| b.overlap_fraction(lower, upper) * b.frequency as f64)
            .sum();

        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
    }

    /// Estimates selectivity for an equality predicate.
    ///
    /// Uses the first bucket containing `value` (with a non-zero distinct
    /// count) and assumes its rows are spread evenly over its distinct values.
    /// Returns `0.01` for an empty histogram and `0.001` when no bucket
    /// contains `value`.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01; // Default fallback
        }

        self.buckets
            .iter()
            .find(|b| b.contains(value) && b.distinct_count > 0)
            .map_or(0.001, |b| {
                (b.frequency as f64 / b.distinct_count as f64 / self.total_rows as f64).min(1.0)
            })
    }

    /// Returns the lower bound of the first bucket, which for histograms made
    /// by [`EquiDepthHistogram::build`] is the smallest sampled value.
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Returns the (exclusive) upper bound of the last bucket.
    ///
    /// For histograms made by [`EquiDepthHistogram::build`] this is the largest
    /// sampled value plus `1.0`, not the largest value itself.
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts distinct values in a slice sorted in ascending order.
///
/// Values are compared exactly: on sorted input equal values are adjacent, so
/// every boundary between two different neighbours starts a new distinct value.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    // An absolute tolerance such as `f64::EPSILON` would merge genuinely
    // different small-magnitude values (e.g. 1e-20 and 2e-20).
    let boundaries = sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();
    1 + boundaries as u64
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
/// Configurable selectivity defaults for cardinality estimation.
///
/// The predicate fields give the fraction of rows (0.0 to 1.0) assumed to
/// satisfy a predicate of that kind when no histogram or column statistics
/// are available to refine the estimate. Tuning them can improve plan quality
/// for workloads with known skew.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Fallback for predicates without a more specific rule (default: 0.1).
    pub default: f64,
    /// Equality predicates without stats (default: 0.01).
    pub equality: f64,
    /// Inequality predicates (default: 0.99).
    pub inequality: f64,
    /// Range predicates without stats (default: 0.33).
    pub range: f64,
    /// STARTS WITH, ENDS WITH, CONTAINS and LIKE (default: 0.1).
    pub string_ops: f64,
    /// IN membership tests (default: 0.1).
    pub membership: f64,
    /// IS NULL (default: 0.05).
    pub is_null: f64,
    /// IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive a DISTINCT (default: 0.5).
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Creates a config holding the standard defaults; identical to
    /// [`SelectivityConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}
446
/// A single estimate-vs-actual observation for analysis.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Human-readable label for the operator (e.g., "NodeScan(Person)").
    pub operator: String,
    /// The cardinality estimate produced by the optimizer.
    pub estimated: f64,
    /// The actual row count observed at execution time.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns the estimation error ratio (actual / estimated).
    ///
    /// Values near 1.0 indicate accurate estimates.
    /// Values > 1.0 indicate underestimation.
    /// Values < 1.0 indicate overestimation.
    /// A (near-)zero estimate yields 1.0 if the actual count is also (near-)zero,
    /// and infinity otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() < f64::EPSILON {
            if self.actual.abs() < f64::EPSILON {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            self.actual / self.estimated
        }
    }
}

/// Collects estimate vs actual cardinality data for query plan analysis.
///
/// After executing a query, call [`record()`](Self::record) for each
/// operator with its estimated and actual cardinalities. Then use
/// [`should_replan()`](Self::should_replan) to decide whether the plan
/// should be re-optimized.
///
/// `Default` produces an empty log with the documented threshold of 10.0.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries.
    entries: Vec<EstimationEntry>,
    /// Error ratio threshold that triggers re-planning (default: 10.0).
    ///
    /// If any operator's error ratio exceeds this (or falls below its
    /// reciprocal), `should_replan()` returns true.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// Creates an empty log with a re-planning threshold of 10.0.
    ///
    /// Deriving `Default` would set the threshold to 0.0, which makes
    /// `should_replan()` return true as soon as any entry is recorded.
    fn default() -> Self {
        Self::new(10.0)
    }
}

impl EstimationLog {
    /// Creates a new estimation log with the given re-planning threshold.
    ///
    /// The threshold is meant to be greater than 1.0; at or below 1.0 every
    /// finite, non-zero ratio triggers a re-plan.
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Records an estimate-vs-actual observation.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// Returns all recorded entries.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns whether any operator's estimation error exceeds the threshold
    /// in either direction (under- or over-estimation), indicating the plan
    /// should be re-optimized.
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Returns the maximum error ratio across all entries, normalized so that
    /// over- and under-estimation both map to values >= 1.0 (1.0 when empty).
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                // Normalize so both over- and under-estimation are > 1.0
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Clears all entries (the threshold is kept).
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
547
548/// Cardinality estimator.
549pub struct CardinalityEstimator {
550    /// Statistics for each label/table.
551    table_stats: HashMap<String, TableStats>,
552    /// Default row count for unknown tables.
553    default_row_count: u64,
554    /// Default selectivity for unknown predicates.
555    default_selectivity: f64,
556    /// Average edge fanout (outgoing edges per node).
557    avg_fanout: f64,
558    /// Configurable selectivity defaults.
559    selectivity_config: SelectivityConfig,
560}
561
562impl CardinalityEstimator {
563    /// Creates a new cardinality estimator with default settings.
564    #[must_use]
565    pub fn new() -> Self {
566        let config = SelectivityConfig::new();
567        Self {
568            table_stats: HashMap::new(),
569            default_row_count: 1000,
570            default_selectivity: config.default,
571            avg_fanout: 10.0,
572            selectivity_config: config,
573        }
574    }
575
576    /// Creates a new cardinality estimator with custom selectivity configuration.
577    #[must_use]
578    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579        Self {
580            table_stats: HashMap::new(),
581            default_row_count: 1000,
582            default_selectivity: config.default,
583            avg_fanout: 10.0,
584            selectivity_config: config,
585        }
586    }
587
588    /// Returns the current selectivity configuration.
589    #[must_use]
590    pub fn selectivity_config(&self) -> &SelectivityConfig {
591        &self.selectivity_config
592    }
593
594    /// Creates an estimation log with the default re-planning threshold (10x).
595    #[must_use]
596    pub fn create_estimation_log() -> EstimationLog {
597        EstimationLog::new(10.0)
598    }
599
600    /// Creates a cardinality estimator pre-populated from store statistics.
601    ///
602    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
603    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
604    /// missing statistics.
605    #[must_use]
606    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607        let mut estimator = Self::new();
608
609        // Use total node count as default for unlabeled scans
610        if stats.total_nodes > 0 {
611            estimator.default_row_count = stats.total_nodes;
612        }
613
614        // Convert label statistics to optimizer table stats
615        for (label, label_stats) in &stats.labels {
616            let mut table_stats = TableStats::new(label_stats.node_count);
617
618            // Map property statistics (distinct count for selectivity estimation)
619            for (prop, col_stats) in &label_stats.properties {
620                let optimizer_col =
621                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622                table_stats = table_stats.with_column(prop, optimizer_col);
623            }
624
625            estimator.add_table_stats(label, table_stats);
626        }
627
628        // Compute average fanout from edge type statistics
629        if !stats.edge_types.is_empty() {
630            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632        } else if stats.total_nodes > 0 {
633            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634        }
635
636        // Clamp fanout to a reasonable minimum
637        if estimator.avg_fanout < 1.0 {
638            estimator.avg_fanout = 1.0;
639        }
640
641        estimator
642    }
643
644    /// Adds statistics for a table/label.
645    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646        self.table_stats.insert(name.to_string(), stats);
647    }
648
649    /// Sets the average edge fanout.
650    pub fn set_avg_fanout(&mut self, fanout: f64) {
651        self.avg_fanout = fanout;
652    }
653
654    /// Estimates the cardinality of a logical operator.
655    #[must_use]
656    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657        match op {
658            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660            LogicalOperator::Project(project) => self.estimate_project(project),
661            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662            LogicalOperator::Join(join) => self.estimate_join(join),
663            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668            LogicalOperator::Return(ret) => self.estimate(&ret.input),
669            LogicalOperator::Empty => 0.0,
670            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672            _ => self.default_row_count as f64,
673        }
674    }
675
676    /// Estimates node scan cardinality.
677    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
678        if let Some(label) = &scan.label
679            && let Some(stats) = self.table_stats.get(label)
680        {
681            return stats.row_count as f64;
682        }
683        // No label filter - scan all nodes
684        self.default_row_count as f64
685    }
686
687    /// Estimates filter cardinality.
688    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
689        let input_cardinality = self.estimate(&filter.input);
690        let selectivity = self.estimate_selectivity(&filter.predicate);
691        (input_cardinality * selectivity).max(1.0)
692    }
693
694    /// Estimates projection cardinality (same as input).
695    fn estimate_project(&self, project: &ProjectOp) -> f64 {
696        self.estimate(&project.input)
697    }
698
699    /// Estimates expand cardinality.
700    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
701        let input_cardinality = self.estimate(&expand.input);
702
703        // Apply fanout based on edge type
704        let fanout = if !expand.edge_types.is_empty() {
705            // Specific edge type(s) typically have lower fanout
706            self.avg_fanout * 0.5
707        } else {
708            self.avg_fanout
709        };
710
711        // Handle variable-length paths
712        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
713            let min = expand.min_hops as f64;
714            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
715            // Geometric series approximation
716            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
717        } else {
718            fanout
719        };
720
721        (input_cardinality * path_multiplier).max(1.0)
722    }
723
724    /// Estimates join cardinality.
725    fn estimate_join(&self, join: &JoinOp) -> f64 {
726        let left_card = self.estimate(&join.left);
727        let right_card = self.estimate(&join.right);
728
729        match join.join_type {
730            JoinType::Cross => left_card * right_card,
731            JoinType::Inner => {
732                // Assume join selectivity based on conditions
733                let selectivity = if join.conditions.is_empty() {
734                    1.0 // Cross join
735                } else {
736                    // Estimate based on number of conditions
737                    0.1_f64.powi(join.conditions.len() as i32)
738                };
739                (left_card * right_card * selectivity).max(1.0)
740            }
741            JoinType::Left => {
742                // Left join returns at least all left rows
743                let inner_card = self.estimate_join(&JoinOp {
744                    left: join.left.clone(),
745                    right: join.right.clone(),
746                    join_type: JoinType::Inner,
747                    conditions: join.conditions.clone(),
748                });
749                inner_card.max(left_card)
750            }
751            JoinType::Right => {
752                // Right join returns at least all right rows
753                let inner_card = self.estimate_join(&JoinOp {
754                    left: join.left.clone(),
755                    right: join.right.clone(),
756                    join_type: JoinType::Inner,
757                    conditions: join.conditions.clone(),
758                });
759                inner_card.max(right_card)
760            }
761            JoinType::Full => {
762                // Full join returns at least max(left, right)
763                let inner_card = self.estimate_join(&JoinOp {
764                    left: join.left.clone(),
765                    right: join.right.clone(),
766                    join_type: JoinType::Inner,
767                    conditions: join.conditions.clone(),
768                });
769                inner_card.max(left_card.max(right_card))
770            }
771            JoinType::Semi => {
772                // Semi join returns at most left cardinality
773                (left_card * self.default_selectivity).max(1.0)
774            }
775            JoinType::Anti => {
776                // Anti join returns at most left cardinality
777                (left_card * (1.0 - self.default_selectivity)).max(1.0)
778            }
779        }
780    }
781
782    /// Estimates aggregation cardinality.
783    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
784        let input_cardinality = self.estimate(&agg.input);
785
786        if agg.group_by.is_empty() {
787            // Global aggregation - single row
788            1.0
789        } else {
790            // Group by - estimate distinct groups
791            // Assume each group key reduces cardinality by 10
792            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
793            (input_cardinality / group_reduction).max(1.0)
794        }
795    }
796
797    /// Estimates sort cardinality (same as input).
798    fn estimate_sort(&self, sort: &SortOp) -> f64 {
799        self.estimate(&sort.input)
800    }
801
802    /// Estimates distinct cardinality.
803    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
804        let input_cardinality = self.estimate(&distinct.input);
805        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
806    }
807
808    /// Estimates limit cardinality.
809    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
810        let input_cardinality = self.estimate(&limit.input);
811        (limit.count as f64).min(input_cardinality)
812    }
813
814    /// Estimates skip cardinality.
815    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
816        let input_cardinality = self.estimate(&skip.input);
817        (input_cardinality - skip.count as f64).max(0.0)
818    }
819
820    /// Estimates vector scan cardinality.
821    ///
822    /// Vector scan returns at most k results (the k nearest neighbors).
823    /// With similarity/distance filters, it may return fewer.
824    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
825        let base_k = scan.k as f64;
826
827        // Apply filter selectivity if thresholds are specified
828        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
829            // Assume 70% of results pass threshold filters
830            0.7
831        } else {
832            1.0
833        };
834
835        (base_k * selectivity).max(1.0)
836    }
837
838    /// Estimates vector join cardinality.
839    ///
840    /// Vector join produces up to k results per input row.
841    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
842        let input_cardinality = self.estimate(&join.input);
843        let k = join.k as f64;
844
845        // Apply filter selectivity if thresholds are specified
846        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
847            0.7
848        } else {
849            1.0
850        };
851
852        (input_cardinality * k * selectivity).max(1.0)
853    }
854
855    /// Estimates the selectivity of a predicate (0.0 to 1.0).
856    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
857        match expr {
858            LogicalExpression::Binary { left, op, right } => {
859                self.estimate_binary_selectivity(left, *op, right)
860            }
861            LogicalExpression::Unary { op, operand } => {
862                self.estimate_unary_selectivity(*op, operand)
863            }
864            LogicalExpression::Literal(value) => {
865                // Boolean literal
866                if let grafeo_common::types::Value::Bool(b) = value {
867                    if *b { 1.0 } else { 0.0 }
868                } else {
869                    self.default_selectivity
870                }
871            }
872            _ => self.default_selectivity,
873        }
874    }
875
876    /// Estimates binary expression selectivity.
877    fn estimate_binary_selectivity(
878        &self,
879        left: &LogicalExpression,
880        op: BinaryOp,
881        right: &LogicalExpression,
882    ) -> f64 {
883        match op {
884            // Equality - try histogram-based estimation
885            BinaryOp::Eq => {
886                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
887                    return selectivity;
888                }
889                self.selectivity_config.equality
890            }
891            // Inequality is very unselective
892            BinaryOp::Ne => self.selectivity_config.inequality,
893            // Range predicates - use histogram if available
894            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
895                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
896                    return selectivity;
897                }
898                self.selectivity_config.range
899            }
900            // Logical operators - recursively estimate sub-expressions
901            BinaryOp::And => {
902                let left_sel = self.estimate_selectivity(left);
903                let right_sel = self.estimate_selectivity(right);
904                // AND reduces selectivity (multiply assuming independence)
905                left_sel * right_sel
906            }
907            BinaryOp::Or => {
908                let left_sel = self.estimate_selectivity(left);
909                let right_sel = self.estimate_selectivity(right);
910                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
911                // Assuming independence: P(A ∩ B) = P(A) * P(B)
912                (left_sel + right_sel - left_sel * right_sel).min(1.0)
913            }
914            // String operations
915            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
916                self.selectivity_config.string_ops
917            }
918            // Collection membership
919            BinaryOp::In => self.selectivity_config.membership,
920            // Other operations
921            _ => self.default_selectivity,
922        }
923    }
924
925    /// Tries to estimate equality selectivity using histograms.
926    fn try_equality_selectivity(
927        &self,
928        left: &LogicalExpression,
929        right: &LogicalExpression,
930    ) -> Option<f64> {
931        // Extract property access and literal value
932        let (label, column, value) = self.extract_column_and_value(left, right)?;
933
934        // Get column stats with histogram
935        let stats = self.get_column_stats(&label, &column)?;
936
937        // Try histogram-based estimation
938        if let Some(ref histogram) = stats.histogram {
939            return Some(histogram.equality_selectivity(value));
940        }
941
942        // Fall back to distinct count estimation
943        if stats.distinct_count > 0 {
944            return Some(1.0 / stats.distinct_count as f64);
945        }
946
947        None
948    }
949
950    /// Tries to estimate range selectivity using histograms.
951    fn try_range_selectivity(
952        &self,
953        left: &LogicalExpression,
954        op: BinaryOp,
955        right: &LogicalExpression,
956    ) -> Option<f64> {
957        // Extract property access and literal value
958        let (label, column, value) = self.extract_column_and_value(left, right)?;
959
960        // Get column stats
961        let stats = self.get_column_stats(&label, &column)?;
962
963        // Determine the range based on operator
964        let (lower, upper) = match op {
965            BinaryOp::Lt => (None, Some(value)),
966            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
967            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
968            BinaryOp::Ge => (Some(value), None),
969            _ => return None,
970        };
971
972        // Try histogram-based estimation first
973        if let Some(ref histogram) = stats.histogram {
974            return Some(histogram.range_selectivity(lower, upper));
975        }
976
977        // Fall back to min/max range estimation
978        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
979            let range = max - min;
980            if range <= 0.0 {
981                return Some(1.0);
982            }
983
984            let effective_lower = lower.unwrap_or(min).max(min);
985            let effective_upper = upper.unwrap_or(max).min(max);
986            let overlap = (effective_upper - effective_lower).max(0.0);
987            return Some((overlap / range).min(1.0).max(0.0));
988        }
989
990        None
991    }
992
993    /// Extracts column information and literal value from a comparison.
994    ///
995    /// Returns (label, column_name, numeric_value) if the expression is
996    /// a comparison between a property access and a numeric literal.
997    fn extract_column_and_value(
998        &self,
999        left: &LogicalExpression,
1000        right: &LogicalExpression,
1001    ) -> Option<(String, String, f64)> {
1002        // Try left as property, right as literal
1003        if let Some(result) = self.try_extract_property_literal(left, right) {
1004            return Some(result);
1005        }
1006
1007        // Try right as property, left as literal
1008        self.try_extract_property_literal(right, left)
1009    }
1010
1011    /// Tries to extract property and literal from a specific ordering.
1012    fn try_extract_property_literal(
1013        &self,
1014        property_expr: &LogicalExpression,
1015        literal_expr: &LogicalExpression,
1016    ) -> Option<(String, String, f64)> {
1017        // Extract property access
1018        let (variable, property) = match property_expr {
1019            LogicalExpression::Property { variable, property } => {
1020                (variable.clone(), property.clone())
1021            }
1022            _ => return None,
1023        };
1024
1025        // Extract numeric literal
1026        let value = match literal_expr {
1027            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1028            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1029            _ => return None,
1030        };
1031
1032        // Try to find a label for this variable from table stats
1033        // Use the variable name as a heuristic label lookup
1034        // In practice, the optimizer would track which labels variables are bound to
1035        for label in self.table_stats.keys() {
1036            if let Some(stats) = self.table_stats.get(label)
1037                && stats.columns.contains_key(&property)
1038            {
1039                return Some((label.clone(), property, value));
1040            }
1041        }
1042
1043        // If no stats found but we have the property, return with variable as label
1044        Some((variable, property, value))
1045    }
1046
1047    /// Estimates unary expression selectivity.
1048    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1049        match op {
1050            UnaryOp::Not => 1.0 - self.default_selectivity,
1051            UnaryOp::IsNull => self.selectivity_config.is_null,
1052            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1053            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1054        }
1055    }
1056
1057    /// Gets statistics for a column.
1058    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1059        self.table_stats.get(label)?.columns.get(column)
1060    }
1061}
1062
1063impl Default for CardinalityEstimator {
1064    fn default() -> Self {
1065        Self::new()
1066    }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071    use super::*;
1072    use crate::query::plan::{
1073        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1074        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1075    };
1076    use grafeo_common::types::Value;
1077
1078    #[test]
1079    fn test_node_scan_with_stats() {
1080        let mut estimator = CardinalityEstimator::new();
1081        estimator.add_table_stats("Person", TableStats::new(5000));
1082
1083        let scan = LogicalOperator::NodeScan(NodeScanOp {
1084            variable: "n".to_string(),
1085            label: Some("Person".to_string()),
1086            input: None,
1087        });
1088
1089        let cardinality = estimator.estimate(&scan);
1090        assert!((cardinality - 5000.0).abs() < 0.001);
1091    }
1092
1093    #[test]
1094    fn test_filter_reduces_cardinality() {
1095        let mut estimator = CardinalityEstimator::new();
1096        estimator.add_table_stats("Person", TableStats::new(1000));
1097
1098        let filter = LogicalOperator::Filter(FilterOp {
1099            predicate: LogicalExpression::Binary {
1100                left: Box::new(LogicalExpression::Property {
1101                    variable: "n".to_string(),
1102                    property: "age".to_string(),
1103                }),
1104                op: BinaryOp::Eq,
1105                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1106            },
1107            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1108                variable: "n".to_string(),
1109                label: Some("Person".to_string()),
1110                input: None,
1111            })),
1112        });
1113
1114        let cardinality = estimator.estimate(&filter);
1115        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1116        assert!(cardinality < 1000.0);
1117        assert!(cardinality >= 1.0);
1118    }
1119
1120    #[test]
1121    fn test_join_cardinality() {
1122        let mut estimator = CardinalityEstimator::new();
1123        estimator.add_table_stats("Person", TableStats::new(1000));
1124        estimator.add_table_stats("Company", TableStats::new(100));
1125
1126        let join = LogicalOperator::Join(JoinOp {
1127            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1128                variable: "p".to_string(),
1129                label: Some("Person".to_string()),
1130                input: None,
1131            })),
1132            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1133                variable: "c".to_string(),
1134                label: Some("Company".to_string()),
1135                input: None,
1136            })),
1137            join_type: JoinType::Inner,
1138            conditions: vec![JoinCondition {
1139                left: LogicalExpression::Property {
1140                    variable: "p".to_string(),
1141                    property: "company_id".to_string(),
1142                },
1143                right: LogicalExpression::Property {
1144                    variable: "c".to_string(),
1145                    property: "id".to_string(),
1146                },
1147            }],
1148        });
1149
1150        let cardinality = estimator.estimate(&join);
1151        // Should be less than cross product
1152        assert!(cardinality < 1000.0 * 100.0);
1153    }
1154
1155    #[test]
1156    fn test_limit_caps_cardinality() {
1157        let mut estimator = CardinalityEstimator::new();
1158        estimator.add_table_stats("Person", TableStats::new(1000));
1159
1160        let limit = LogicalOperator::Limit(LimitOp {
1161            count: 10,
1162            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1163                variable: "n".to_string(),
1164                label: Some("Person".to_string()),
1165                input: None,
1166            })),
1167        });
1168
1169        let cardinality = estimator.estimate(&limit);
1170        assert!((cardinality - 10.0).abs() < 0.001);
1171    }
1172
1173    #[test]
1174    fn test_aggregate_reduces_cardinality() {
1175        let mut estimator = CardinalityEstimator::new();
1176        estimator.add_table_stats("Person", TableStats::new(1000));
1177
1178        // Global aggregation
1179        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1180            group_by: vec![],
1181            aggregates: vec![],
1182            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1183                variable: "n".to_string(),
1184                label: Some("Person".to_string()),
1185                input: None,
1186            })),
1187            having: None,
1188        });
1189
1190        let cardinality = estimator.estimate(&global_agg);
1191        assert!((cardinality - 1.0).abs() < 0.001);
1192
1193        // Group by aggregation
1194        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1195            group_by: vec![LogicalExpression::Property {
1196                variable: "n".to_string(),
1197                property: "city".to_string(),
1198            }],
1199            aggregates: vec![],
1200            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1201                variable: "n".to_string(),
1202                label: Some("Person".to_string()),
1203                input: None,
1204            })),
1205            having: None,
1206        });
1207
1208        let cardinality = estimator.estimate(&group_agg);
1209        // Should be less than input
1210        assert!(cardinality < 1000.0);
1211    }
1212
1213    #[test]
1214    fn test_node_scan_without_stats() {
1215        let estimator = CardinalityEstimator::new();
1216
1217        let scan = LogicalOperator::NodeScan(NodeScanOp {
1218            variable: "n".to_string(),
1219            label: Some("Unknown".to_string()),
1220            input: None,
1221        });
1222
1223        let cardinality = estimator.estimate(&scan);
1224        // Should return default (1000)
1225        assert!((cardinality - 1000.0).abs() < 0.001);
1226    }
1227
1228    #[test]
1229    fn test_node_scan_no_label() {
1230        let estimator = CardinalityEstimator::new();
1231
1232        let scan = LogicalOperator::NodeScan(NodeScanOp {
1233            variable: "n".to_string(),
1234            label: None,
1235            input: None,
1236        });
1237
1238        let cardinality = estimator.estimate(&scan);
1239        // Should scan all nodes (default)
1240        assert!((cardinality - 1000.0).abs() < 0.001);
1241    }
1242
1243    #[test]
1244    fn test_filter_inequality_selectivity() {
1245        let mut estimator = CardinalityEstimator::new();
1246        estimator.add_table_stats("Person", TableStats::new(1000));
1247
1248        let filter = LogicalOperator::Filter(FilterOp {
1249            predicate: LogicalExpression::Binary {
1250                left: Box::new(LogicalExpression::Property {
1251                    variable: "n".to_string(),
1252                    property: "age".to_string(),
1253                }),
1254                op: BinaryOp::Ne,
1255                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1256            },
1257            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1258                variable: "n".to_string(),
1259                label: Some("Person".to_string()),
1260                input: None,
1261            })),
1262        });
1263
1264        let cardinality = estimator.estimate(&filter);
1265        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1266        assert!(cardinality > 900.0);
1267    }
1268
1269    #[test]
1270    fn test_filter_range_selectivity() {
1271        let mut estimator = CardinalityEstimator::new();
1272        estimator.add_table_stats("Person", TableStats::new(1000));
1273
1274        let filter = LogicalOperator::Filter(FilterOp {
1275            predicate: LogicalExpression::Binary {
1276                left: Box::new(LogicalExpression::Property {
1277                    variable: "n".to_string(),
1278                    property: "age".to_string(),
1279                }),
1280                op: BinaryOp::Gt,
1281                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1282            },
1283            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1284                variable: "n".to_string(),
1285                label: Some("Person".to_string()),
1286                input: None,
1287            })),
1288        });
1289
1290        let cardinality = estimator.estimate(&filter);
1291        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1292        assert!(cardinality < 500.0);
1293        assert!(cardinality > 100.0);
1294    }
1295
1296    #[test]
1297    fn test_filter_and_selectivity() {
1298        let mut estimator = CardinalityEstimator::new();
1299        estimator.add_table_stats("Person", TableStats::new(1000));
1300
1301        // Test AND with two equality predicates
1302        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1303        let filter = LogicalOperator::Filter(FilterOp {
1304            predicate: LogicalExpression::Binary {
1305                left: Box::new(LogicalExpression::Binary {
1306                    left: Box::new(LogicalExpression::Property {
1307                        variable: "n".to_string(),
1308                        property: "city".to_string(),
1309                    }),
1310                    op: BinaryOp::Eq,
1311                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1312                }),
1313                op: BinaryOp::And,
1314                right: Box::new(LogicalExpression::Binary {
1315                    left: Box::new(LogicalExpression::Property {
1316                        variable: "n".to_string(),
1317                        property: "age".to_string(),
1318                    }),
1319                    op: BinaryOp::Eq,
1320                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1321                }),
1322            },
1323            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1324                variable: "n".to_string(),
1325                label: Some("Person".to_string()),
1326                input: None,
1327            })),
1328        });
1329
1330        let cardinality = estimator.estimate(&filter);
1331        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1332        // 1000 * 0.0001 = 0.1, min is 1.0
1333        assert!(cardinality < 100.0);
1334        assert!(cardinality >= 1.0);
1335    }
1336
1337    #[test]
1338    fn test_filter_or_selectivity() {
1339        let mut estimator = CardinalityEstimator::new();
1340        estimator.add_table_stats("Person", TableStats::new(1000));
1341
1342        // Test OR with two equality predicates
1343        // Each equality has selectivity 0.01
1344        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1345        let filter = LogicalOperator::Filter(FilterOp {
1346            predicate: LogicalExpression::Binary {
1347                left: Box::new(LogicalExpression::Binary {
1348                    left: Box::new(LogicalExpression::Property {
1349                        variable: "n".to_string(),
1350                        property: "city".to_string(),
1351                    }),
1352                    op: BinaryOp::Eq,
1353                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1354                }),
1355                op: BinaryOp::Or,
1356                right: Box::new(LogicalExpression::Binary {
1357                    left: Box::new(LogicalExpression::Property {
1358                        variable: "n".to_string(),
1359                        property: "city".to_string(),
1360                    }),
1361                    op: BinaryOp::Eq,
1362                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1363                }),
1364            },
1365            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1366                variable: "n".to_string(),
1367                label: Some("Person".to_string()),
1368                input: None,
1369            })),
1370        });
1371
1372        let cardinality = estimator.estimate(&filter);
1373        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1374        assert!(cardinality < 100.0);
1375        assert!(cardinality >= 1.0);
1376    }
1377
1378    #[test]
1379    fn test_filter_literal_true() {
1380        let mut estimator = CardinalityEstimator::new();
1381        estimator.add_table_stats("Person", TableStats::new(1000));
1382
1383        let filter = LogicalOperator::Filter(FilterOp {
1384            predicate: LogicalExpression::Literal(Value::Bool(true)),
1385            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1386                variable: "n".to_string(),
1387                label: Some("Person".to_string()),
1388                input: None,
1389            })),
1390        });
1391
1392        let cardinality = estimator.estimate(&filter);
1393        // Literal true has selectivity 1.0
1394        assert!((cardinality - 1000.0).abs() < 0.001);
1395    }
1396
1397    #[test]
1398    fn test_filter_literal_false() {
1399        let mut estimator = CardinalityEstimator::new();
1400        estimator.add_table_stats("Person", TableStats::new(1000));
1401
1402        let filter = LogicalOperator::Filter(FilterOp {
1403            predicate: LogicalExpression::Literal(Value::Bool(false)),
1404            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1405                variable: "n".to_string(),
1406                label: Some("Person".to_string()),
1407                input: None,
1408            })),
1409        });
1410
1411        let cardinality = estimator.estimate(&filter);
1412        // Literal false has selectivity 0.0, but min is 1.0
1413        assert!((cardinality - 1.0).abs() < 0.001);
1414    }
1415
1416    #[test]
1417    fn test_unary_not_selectivity() {
1418        let mut estimator = CardinalityEstimator::new();
1419        estimator.add_table_stats("Person", TableStats::new(1000));
1420
1421        let filter = LogicalOperator::Filter(FilterOp {
1422            predicate: LogicalExpression::Unary {
1423                op: UnaryOp::Not,
1424                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1425            },
1426            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1427                variable: "n".to_string(),
1428                label: Some("Person".to_string()),
1429                input: None,
1430            })),
1431        });
1432
1433        let cardinality = estimator.estimate(&filter);
1434        // NOT inverts selectivity
1435        assert!(cardinality < 1000.0);
1436    }
1437
1438    #[test]
1439    fn test_unary_is_null_selectivity() {
1440        let mut estimator = CardinalityEstimator::new();
1441        estimator.add_table_stats("Person", TableStats::new(1000));
1442
1443        let filter = LogicalOperator::Filter(FilterOp {
1444            predicate: LogicalExpression::Unary {
1445                op: UnaryOp::IsNull,
1446                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1447            },
1448            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1449                variable: "n".to_string(),
1450                label: Some("Person".to_string()),
1451                input: None,
1452            })),
1453        });
1454
1455        let cardinality = estimator.estimate(&filter);
1456        // IS NULL has selectivity 0.05
1457        assert!(cardinality < 100.0);
1458    }
1459
1460    #[test]
1461    fn test_expand_cardinality() {
1462        let mut estimator = CardinalityEstimator::new();
1463        estimator.add_table_stats("Person", TableStats::new(100));
1464
1465        let expand = LogicalOperator::Expand(ExpandOp {
1466            from_variable: "a".to_string(),
1467            to_variable: "b".to_string(),
1468            edge_variable: None,
1469            direction: ExpandDirection::Outgoing,
1470            edge_types: vec![],
1471            min_hops: 1,
1472            max_hops: Some(1),
1473            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1474                variable: "a".to_string(),
1475                label: Some("Person".to_string()),
1476                input: None,
1477            })),
1478            path_alias: None,
1479            path_mode: PathMode::Walk,
1480        });
1481
1482        let cardinality = estimator.estimate(&expand);
1483        // Expand multiplies by fanout (10)
1484        assert!(cardinality > 100.0);
1485    }
1486
1487    #[test]
1488    fn test_expand_with_edge_type_filter() {
1489        let mut estimator = CardinalityEstimator::new();
1490        estimator.add_table_stats("Person", TableStats::new(100));
1491
1492        let expand = LogicalOperator::Expand(ExpandOp {
1493            from_variable: "a".to_string(),
1494            to_variable: "b".to_string(),
1495            edge_variable: None,
1496            direction: ExpandDirection::Outgoing,
1497            edge_types: vec!["KNOWS".to_string()],
1498            min_hops: 1,
1499            max_hops: Some(1),
1500            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1501                variable: "a".to_string(),
1502                label: Some("Person".to_string()),
1503                input: None,
1504            })),
1505            path_alias: None,
1506            path_mode: PathMode::Walk,
1507        });
1508
1509        let cardinality = estimator.estimate(&expand);
1510        // With edge type, fanout is reduced by half
1511        assert!(cardinality > 100.0);
1512    }
1513
1514    #[test]
1515    fn test_expand_variable_length() {
1516        let mut estimator = CardinalityEstimator::new();
1517        estimator.add_table_stats("Person", TableStats::new(100));
1518
1519        let expand = LogicalOperator::Expand(ExpandOp {
1520            from_variable: "a".to_string(),
1521            to_variable: "b".to_string(),
1522            edge_variable: None,
1523            direction: ExpandDirection::Outgoing,
1524            edge_types: vec![],
1525            min_hops: 1,
1526            max_hops: Some(3),
1527            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1528                variable: "a".to_string(),
1529                label: Some("Person".to_string()),
1530                input: None,
1531            })),
1532            path_alias: None,
1533            path_mode: PathMode::Walk,
1534        });
1535
1536        let cardinality = estimator.estimate(&expand);
1537        // Variable length path has much higher cardinality
1538        assert!(cardinality > 500.0);
1539    }
1540
1541    #[test]
1542    fn test_join_cross_product() {
1543        let mut estimator = CardinalityEstimator::new();
1544        estimator.add_table_stats("Person", TableStats::new(100));
1545        estimator.add_table_stats("Company", TableStats::new(50));
1546
1547        let join = LogicalOperator::Join(JoinOp {
1548            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1549                variable: "p".to_string(),
1550                label: Some("Person".to_string()),
1551                input: None,
1552            })),
1553            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1554                variable: "c".to_string(),
1555                label: Some("Company".to_string()),
1556                input: None,
1557            })),
1558            join_type: JoinType::Cross,
1559            conditions: vec![],
1560        });
1561
1562        let cardinality = estimator.estimate(&join);
1563        // Cross join = 100 * 50 = 5000
1564        assert!((cardinality - 5000.0).abs() < 0.001);
1565    }
1566
1567    #[test]
1568    fn test_join_left_outer() {
1569        let mut estimator = CardinalityEstimator::new();
1570        estimator.add_table_stats("Person", TableStats::new(1000));
1571        estimator.add_table_stats("Company", TableStats::new(10));
1572
1573        let join = LogicalOperator::Join(JoinOp {
1574            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1575                variable: "p".to_string(),
1576                label: Some("Person".to_string()),
1577                input: None,
1578            })),
1579            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1580                variable: "c".to_string(),
1581                label: Some("Company".to_string()),
1582                input: None,
1583            })),
1584            join_type: JoinType::Left,
1585            conditions: vec![JoinCondition {
1586                left: LogicalExpression::Variable("p".to_string()),
1587                right: LogicalExpression::Variable("c".to_string()),
1588            }],
1589        });
1590
1591        let cardinality = estimator.estimate(&join);
1592        // Left join returns at least all left rows
1593        assert!(cardinality >= 1000.0);
1594    }
1595
1596    #[test]
1597    fn test_join_semi() {
1598        let mut estimator = CardinalityEstimator::new();
1599        estimator.add_table_stats("Person", TableStats::new(1000));
1600        estimator.add_table_stats("Company", TableStats::new(100));
1601
1602        let join = LogicalOperator::Join(JoinOp {
1603            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1604                variable: "p".to_string(),
1605                label: Some("Person".to_string()),
1606                input: None,
1607            })),
1608            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1609                variable: "c".to_string(),
1610                label: Some("Company".to_string()),
1611                input: None,
1612            })),
1613            join_type: JoinType::Semi,
1614            conditions: vec![],
1615        });
1616
1617        let cardinality = estimator.estimate(&join);
1618        // Semi join returns at most left cardinality
1619        assert!(cardinality <= 1000.0);
1620    }
1621
1622    #[test]
1623    fn test_join_anti() {
1624        let mut estimator = CardinalityEstimator::new();
1625        estimator.add_table_stats("Person", TableStats::new(1000));
1626        estimator.add_table_stats("Company", TableStats::new(100));
1627
1628        let join = LogicalOperator::Join(JoinOp {
1629            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1630                variable: "p".to_string(),
1631                label: Some("Person".to_string()),
1632                input: None,
1633            })),
1634            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1635                variable: "c".to_string(),
1636                label: Some("Company".to_string()),
1637                input: None,
1638            })),
1639            join_type: JoinType::Anti,
1640            conditions: vec![],
1641        });
1642
1643        let cardinality = estimator.estimate(&join);
1644        // Anti join returns at most left cardinality
1645        assert!(cardinality <= 1000.0);
1646        assert!(cardinality >= 1.0);
1647    }
1648
1649    #[test]
1650    fn test_project_preserves_cardinality() {
1651        let mut estimator = CardinalityEstimator::new();
1652        estimator.add_table_stats("Person", TableStats::new(1000));
1653
1654        let project = LogicalOperator::Project(ProjectOp {
1655            projections: vec![Projection {
1656                expression: LogicalExpression::Variable("n".to_string()),
1657                alias: None,
1658            }],
1659            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1660                variable: "n".to_string(),
1661                label: Some("Person".to_string()),
1662                input: None,
1663            })),
1664        });
1665
1666        let cardinality = estimator.estimate(&project);
1667        assert!((cardinality - 1000.0).abs() < 0.001);
1668    }
1669
1670    #[test]
1671    fn test_sort_preserves_cardinality() {
1672        let mut estimator = CardinalityEstimator::new();
1673        estimator.add_table_stats("Person", TableStats::new(1000));
1674
1675        let sort = LogicalOperator::Sort(SortOp {
1676            keys: vec![SortKey {
1677                expression: LogicalExpression::Variable("n".to_string()),
1678                order: SortOrder::Ascending,
1679            }],
1680            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1681                variable: "n".to_string(),
1682                label: Some("Person".to_string()),
1683                input: None,
1684            })),
1685        });
1686
1687        let cardinality = estimator.estimate(&sort);
1688        assert!((cardinality - 1000.0).abs() < 0.001);
1689    }
1690
1691    #[test]
1692    fn test_distinct_reduces_cardinality() {
1693        let mut estimator = CardinalityEstimator::new();
1694        estimator.add_table_stats("Person", TableStats::new(1000));
1695
1696        let distinct = LogicalOperator::Distinct(DistinctOp {
1697            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1698                variable: "n".to_string(),
1699                label: Some("Person".to_string()),
1700                input: None,
1701            })),
1702            columns: None,
1703        });
1704
1705        let cardinality = estimator.estimate(&distinct);
1706        // Distinct assumes 50% unique
1707        assert!((cardinality - 500.0).abs() < 0.001);
1708    }
1709
1710    #[test]
1711    fn test_skip_reduces_cardinality() {
1712        let mut estimator = CardinalityEstimator::new();
1713        estimator.add_table_stats("Person", TableStats::new(1000));
1714
1715        let skip = LogicalOperator::Skip(SkipOp {
1716            count: 100,
1717            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1718                variable: "n".to_string(),
1719                label: Some("Person".to_string()),
1720                input: None,
1721            })),
1722        });
1723
1724        let cardinality = estimator.estimate(&skip);
1725        assert!((cardinality - 900.0).abs() < 0.001);
1726    }
1727
1728    #[test]
1729    fn test_return_preserves_cardinality() {
1730        let mut estimator = CardinalityEstimator::new();
1731        estimator.add_table_stats("Person", TableStats::new(1000));
1732
1733        let ret = LogicalOperator::Return(ReturnOp {
1734            items: vec![ReturnItem {
1735                expression: LogicalExpression::Variable("n".to_string()),
1736                alias: None,
1737            }],
1738            distinct: false,
1739            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1740                variable: "n".to_string(),
1741                label: Some("Person".to_string()),
1742                input: None,
1743            })),
1744        });
1745
1746        let cardinality = estimator.estimate(&ret);
1747        assert!((cardinality - 1000.0).abs() < 0.001);
1748    }
1749
1750    #[test]
1751    fn test_empty_cardinality() {
1752        let estimator = CardinalityEstimator::new();
1753        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1754        assert!((cardinality).abs() < 0.001);
1755    }
1756
1757    #[test]
1758    fn test_table_stats_with_column() {
1759        let stats = TableStats::new(1000).with_column(
1760            "age",
1761            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1762        );
1763
1764        assert_eq!(stats.row_count, 1000);
1765        let col = stats.columns.get("age").unwrap();
1766        assert_eq!(col.distinct_count, 50);
1767        assert_eq!(col.null_count, 10);
1768        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1769        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1770    }
1771
1772    #[test]
1773    fn test_estimator_default() {
1774        let estimator = CardinalityEstimator::default();
1775        let scan = LogicalOperator::NodeScan(NodeScanOp {
1776            variable: "n".to_string(),
1777            label: None,
1778            input: None,
1779        });
1780        let cardinality = estimator.estimate(&scan);
1781        assert!((cardinality - 1000.0).abs() < 0.001);
1782    }
1783
1784    #[test]
1785    fn test_set_avg_fanout() {
1786        let mut estimator = CardinalityEstimator::new();
1787        estimator.add_table_stats("Person", TableStats::new(100));
1788        estimator.set_avg_fanout(5.0);
1789
1790        let expand = LogicalOperator::Expand(ExpandOp {
1791            from_variable: "a".to_string(),
1792            to_variable: "b".to_string(),
1793            edge_variable: None,
1794            direction: ExpandDirection::Outgoing,
1795            edge_types: vec![],
1796            min_hops: 1,
1797            max_hops: Some(1),
1798            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1799                variable: "a".to_string(),
1800                label: Some("Person".to_string()),
1801                input: None,
1802            })),
1803            path_alias: None,
1804            path_mode: PathMode::Walk,
1805        });
1806
1807        let cardinality = estimator.estimate(&expand);
1808        // With fanout 5: 100 * 5 = 500
1809        assert!((cardinality - 500.0).abs() < 0.001);
1810    }
1811
1812    #[test]
1813    fn test_multiple_group_by_keys_reduce_cardinality() {
1814        // The current implementation uses a simplified model where more group by keys
1815        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1816        // that works for most cases where group by keys are correlated.
1817        let mut estimator = CardinalityEstimator::new();
1818        estimator.add_table_stats("Person", TableStats::new(10000));
1819
1820        let single_group = LogicalOperator::Aggregate(AggregateOp {
1821            group_by: vec![LogicalExpression::Property {
1822                variable: "n".to_string(),
1823                property: "city".to_string(),
1824            }],
1825            aggregates: vec![],
1826            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1827                variable: "n".to_string(),
1828                label: Some("Person".to_string()),
1829                input: None,
1830            })),
1831            having: None,
1832        });
1833
1834        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1835            group_by: vec![
1836                LogicalExpression::Property {
1837                    variable: "n".to_string(),
1838                    property: "city".to_string(),
1839                },
1840                LogicalExpression::Property {
1841                    variable: "n".to_string(),
1842                    property: "country".to_string(),
1843                },
1844            ],
1845            aggregates: vec![],
1846            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1847                variable: "n".to_string(),
1848                label: Some("Person".to_string()),
1849                input: None,
1850            })),
1851            having: None,
1852        });
1853
1854        let single_card = estimator.estimate(&single_group);
1855        let multi_card = estimator.estimate(&multi_group);
1856
1857        // Both should reduce cardinality from input
1858        assert!(single_card < 10000.0);
1859        assert!(multi_card < 10000.0);
1860        // Both should be at least 1
1861        assert!(single_card >= 1.0);
1862        assert!(multi_card >= 1.0);
1863    }
1864
1865    // ============= Histogram Tests =============
1866
1867    #[test]
1868    fn test_histogram_build_uniform() {
1869        // Build histogram from uniformly distributed data
1870        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1871        let histogram = EquiDepthHistogram::build(&values, 10);
1872
1873        assert_eq!(histogram.num_buckets(), 10);
1874        assert_eq!(histogram.total_rows(), 100);
1875
1876        // Each bucket should have approximately 10 rows
1877        for bucket in histogram.buckets() {
1878            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1879        }
1880    }
1881
1882    #[test]
1883    fn test_histogram_build_skewed() {
1884        // Build histogram from skewed data (many small values, few large)
1885        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1886        values.extend((0..20).map(|i| 1000.0 + i as f64));
1887        let histogram = EquiDepthHistogram::build(&values, 5);
1888
1889        assert_eq!(histogram.num_buckets(), 5);
1890        assert_eq!(histogram.total_rows(), 100);
1891
1892        // Each bucket should have ~20 rows despite skewed data
1893        for bucket in histogram.buckets() {
1894            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1895        }
1896    }
1897
1898    #[test]
1899    fn test_histogram_range_selectivity_full() {
1900        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1901        let histogram = EquiDepthHistogram::build(&values, 10);
1902
1903        // Full range should have selectivity ~1.0
1904        let selectivity = histogram.range_selectivity(None, None);
1905        assert!((selectivity - 1.0).abs() < 0.01);
1906    }
1907
1908    #[test]
1909    fn test_histogram_range_selectivity_half() {
1910        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1911        let histogram = EquiDepthHistogram::build(&values, 10);
1912
1913        // Values >= 50 should be ~50% (half the data)
1914        let selectivity = histogram.range_selectivity(Some(50.0), None);
1915        assert!(selectivity > 0.4 && selectivity < 0.6);
1916    }
1917
1918    #[test]
1919    fn test_histogram_range_selectivity_quarter() {
1920        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1921        let histogram = EquiDepthHistogram::build(&values, 10);
1922
1923        // Values < 25 should be ~25%
1924        let selectivity = histogram.range_selectivity(None, Some(25.0));
1925        assert!(selectivity > 0.2 && selectivity < 0.3);
1926    }
1927
1928    #[test]
1929    fn test_histogram_equality_selectivity() {
1930        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1931        let histogram = EquiDepthHistogram::build(&values, 10);
1932
1933        // Equality on 100 distinct values should be ~1%
1934        let selectivity = histogram.equality_selectivity(50.0);
1935        assert!(selectivity > 0.005 && selectivity < 0.02);
1936    }
1937
1938    #[test]
1939    fn test_histogram_empty() {
1940        let histogram = EquiDepthHistogram::build(&[], 10);
1941
1942        assert_eq!(histogram.num_buckets(), 0);
1943        assert_eq!(histogram.total_rows(), 0);
1944
1945        // Default selectivity for empty histogram
1946        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1947        assert!((selectivity - 0.33).abs() < 0.01);
1948    }
1949
1950    #[test]
1951    fn test_histogram_bucket_overlap() {
1952        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1953
1954        // Full overlap
1955        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1956
1957        // Half overlap (lower half)
1958        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1959
1960        // Half overlap (upper half)
1961        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1962
1963        // No overlap (below)
1964        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1965
1966        // No overlap (above)
1967        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1968    }
1969
1970    #[test]
1971    fn test_column_stats_from_values() {
1972        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
1973        let stats = ColumnStats::from_values(values, 4);
1974
1975        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
1976        assert!(stats.min_value.is_some());
1977        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
1978        assert!(stats.max_value.is_some());
1979        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
1980        assert!(stats.histogram.is_some());
1981    }
1982
1983    #[test]
1984    fn test_column_stats_with_histogram_builder() {
1985        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1986        let histogram = EquiDepthHistogram::build(&values, 10);
1987
1988        let stats = ColumnStats::new(100)
1989            .with_range(0.0, 99.0)
1990            .with_histogram(histogram);
1991
1992        assert!(stats.histogram.is_some());
1993        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
1994    }
1995
1996    #[test]
1997    fn test_filter_with_histogram_stats() {
1998        let mut estimator = CardinalityEstimator::new();
1999
2000        // Create stats with histogram for age column
2001        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2002        let histogram = EquiDepthHistogram::build(&age_values, 10);
2003        let age_stats = ColumnStats::new(62)
2004            .with_range(18.0, 79.0)
2005            .with_histogram(histogram);
2006
2007        estimator.add_table_stats(
2008            "Person",
2009            TableStats::new(1000).with_column("age", age_stats),
2010        );
2011
2012        // Filter: age > 50
2013        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
2014        let filter = LogicalOperator::Filter(FilterOp {
2015            predicate: LogicalExpression::Binary {
2016                left: Box::new(LogicalExpression::Property {
2017                    variable: "n".to_string(),
2018                    property: "age".to_string(),
2019                }),
2020                op: BinaryOp::Gt,
2021                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2022            },
2023            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2024                variable: "n".to_string(),
2025                label: Some("Person".to_string()),
2026                input: None,
2027            })),
2028        });
2029
2030        let cardinality = estimator.estimate(&filter);
2031
2032        // With histogram, should get more accurate estimate than default 0.33
2033        // Expected: ~47.5% of 1000 = ~475
2034        assert!(cardinality > 300.0 && cardinality < 600.0);
2035    }
2036
2037    #[test]
2038    fn test_filter_equality_with_histogram() {
2039        let mut estimator = CardinalityEstimator::new();
2040
2041        // Create stats with histogram
2042        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2043        let histogram = EquiDepthHistogram::build(&values, 10);
2044        let stats = ColumnStats::new(100)
2045            .with_range(0.0, 99.0)
2046            .with_histogram(histogram);
2047
2048        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2049
2050        // Filter: value = 50
2051        let filter = LogicalOperator::Filter(FilterOp {
2052            predicate: LogicalExpression::Binary {
2053                left: Box::new(LogicalExpression::Property {
2054                    variable: "d".to_string(),
2055                    property: "value".to_string(),
2056                }),
2057                op: BinaryOp::Eq,
2058                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2059            },
2060            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2061                variable: "d".to_string(),
2062                label: Some("Data".to_string()),
2063                input: None,
2064            })),
2065        });
2066
2067        let cardinality = estimator.estimate(&filter);
2068
2069        // With 100 distinct values, selectivity should be ~1%
2070        // 1000 * 0.01 = 10
2071        assert!((1.0..50.0).contains(&cardinality));
2072    }
2073
2074    #[test]
2075    fn test_histogram_min_max() {
2076        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2077        let histogram = EquiDepthHistogram::build(&values, 2);
2078
2079        assert_eq!(histogram.min_value(), Some(5.0));
2080        // Max is the upper bound of the last bucket
2081        assert!(histogram.max_value().is_some());
2082    }
2083
2084    // ==================== SelectivityConfig Tests ====================
2085
2086    #[test]
2087    fn test_selectivity_config_defaults() {
2088        let config = SelectivityConfig::new();
2089        assert!((config.default - 0.1).abs() < f64::EPSILON);
2090        assert!((config.equality - 0.01).abs() < f64::EPSILON);
2091        assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2092        assert!((config.range - 0.33).abs() < f64::EPSILON);
2093        assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2094        assert!((config.membership - 0.1).abs() < f64::EPSILON);
2095        assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2096        assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2097        assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2098    }
2099
2100    #[test]
2101    fn test_custom_selectivity_config() {
2102        let config = SelectivityConfig {
2103            equality: 0.05,
2104            range: 0.25,
2105            ..SelectivityConfig::new()
2106        };
2107        let estimator = CardinalityEstimator::with_selectivity_config(config);
2108        assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2109        assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2110    }
2111
2112    #[test]
2113    fn test_custom_selectivity_affects_estimation() {
2114        // Default: equality = 0.01 → 1000 * 0.01 = 10
2115        let mut default_est = CardinalityEstimator::new();
2116        default_est.add_table_stats("Person", TableStats::new(1000));
2117
2118        let filter = LogicalOperator::Filter(FilterOp {
2119            predicate: LogicalExpression::Binary {
2120                left: Box::new(LogicalExpression::Property {
2121                    variable: "n".to_string(),
2122                    property: "name".to_string(),
2123                }),
2124                op: BinaryOp::Eq,
2125                right: Box::new(LogicalExpression::Literal(Value::String("Alice".into()))),
2126            },
2127            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2128                variable: "n".to_string(),
2129                label: Some("Person".to_string()),
2130                input: None,
2131            })),
2132        });
2133
2134        let default_card = default_est.estimate(&filter);
2135
2136        // Custom: equality = 0.2 → 1000 * 0.2 = 200
2137        let config = SelectivityConfig {
2138            equality: 0.2,
2139            ..SelectivityConfig::new()
2140        };
2141        let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2142        custom_est.add_table_stats("Person", TableStats::new(1000));
2143
2144        let custom_card = custom_est.estimate(&filter);
2145
2146        assert!(custom_card > default_card);
2147        assert!((custom_card - 200.0).abs() < 1.0);
2148    }
2149
2150    #[test]
2151    fn test_custom_range_selectivity() {
2152        let config = SelectivityConfig {
2153            range: 0.5,
2154            ..SelectivityConfig::new()
2155        };
2156        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2157        estimator.add_table_stats("Person", TableStats::new(1000));
2158
2159        let filter = LogicalOperator::Filter(FilterOp {
2160            predicate: LogicalExpression::Binary {
2161                left: Box::new(LogicalExpression::Property {
2162                    variable: "n".to_string(),
2163                    property: "age".to_string(),
2164                }),
2165                op: BinaryOp::Gt,
2166                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2167            },
2168            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2169                variable: "n".to_string(),
2170                label: Some("Person".to_string()),
2171                input: None,
2172            })),
2173        });
2174
2175        let cardinality = estimator.estimate(&filter);
2176        // 1000 * 0.5 = 500
2177        assert!((cardinality - 500.0).abs() < 1.0);
2178    }
2179
2180    #[test]
2181    fn test_custom_distinct_fraction() {
2182        let config = SelectivityConfig {
2183            distinct_fraction: 0.8,
2184            ..SelectivityConfig::new()
2185        };
2186        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2187        estimator.add_table_stats("Person", TableStats::new(1000));
2188
2189        let distinct = LogicalOperator::Distinct(DistinctOp {
2190            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2191                variable: "n".to_string(),
2192                label: Some("Person".to_string()),
2193                input: None,
2194            })),
2195            columns: None,
2196        });
2197
2198        let cardinality = estimator.estimate(&distinct);
2199        // 1000 * 0.8 = 800
2200        assert!((cardinality - 800.0).abs() < 1.0);
2201    }
2202
2203    // ==================== EstimationLog Tests ====================
2204
2205    #[test]
2206    fn test_estimation_log_basic() {
2207        let mut log = EstimationLog::new(10.0);
2208        log.record("NodeScan(Person)", 1000.0, 1200.0);
2209        log.record("Filter(age > 30)", 100.0, 90.0);
2210
2211        assert_eq!(log.entries().len(), 2);
2212        assert!(!log.should_replan()); // 1.2x and 0.9x are within 10x threshold
2213    }
2214
2215    #[test]
2216    fn test_estimation_log_triggers_replan() {
2217        let mut log = EstimationLog::new(10.0);
2218        log.record("NodeScan(Person)", 100.0, 5000.0); // 50x underestimate
2219
2220        assert!(log.should_replan());
2221    }
2222
2223    #[test]
2224    fn test_estimation_log_overestimate_triggers_replan() {
2225        let mut log = EstimationLog::new(5.0);
2226        log.record("Filter", 1000.0, 100.0); // 10x overestimate → ratio = 0.1
2227
2228        assert!(log.should_replan()); // 0.1 < 1/5 = 0.2
2229    }
2230
2231    #[test]
2232    fn test_estimation_entry_error_ratio() {
2233        let entry = EstimationEntry {
2234            operator: "test".into(),
2235            estimated: 100.0,
2236            actual: 200.0,
2237        };
2238        assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2239
2240        let perfect = EstimationEntry {
2241            operator: "test".into(),
2242            estimated: 100.0,
2243            actual: 100.0,
2244        };
2245        assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2246
2247        let zero_est = EstimationEntry {
2248            operator: "test".into(),
2249            estimated: 0.0,
2250            actual: 0.0,
2251        };
2252        assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2253    }
2254
2255    #[test]
2256    fn test_estimation_log_max_error_ratio() {
2257        let mut log = EstimationLog::new(10.0);
2258        log.record("A", 100.0, 300.0); // 3x
2259        log.record("B", 100.0, 50.0); // 2x (normalized: 1/0.5 = 2)
2260        log.record("C", 100.0, 100.0); // 1x
2261
2262        assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2263    }
2264
2265    #[test]
2266    fn test_estimation_log_clear() {
2267        let mut log = EstimationLog::new(10.0);
2268        log.record("A", 100.0, 100.0);
2269        assert_eq!(log.entries().len(), 1);
2270
2271        log.clear();
2272        assert!(log.entries().is_empty());
2273        assert!(!log.should_replan());
2274    }
2275
2276    #[test]
2277    fn test_create_estimation_log() {
2278        let log = CardinalityEstimator::create_estimation_log();
2279        assert!(log.entries().is_empty());
2280        assert!(!log.should_replan());
2281    }
2282}