Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20    UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One range of an equi-depth histogram.
///
/// A bucket records the value interval it spans together with how many rows
/// and how many distinct values were attributed to that interval. Buckets of
/// one histogram are built to hold roughly the same number of rows.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Lower edge of the value range (inclusive).
    pub lower_bound: f64,
    /// Upper edge of the value range; [`HistogramBucket::contains`] treats it as exclusive.
    pub upper_bound: f64,
    /// Number of rows attributed to this bucket.
    pub frequency: u64,
    /// Number of distinct values attributed to this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Builds a bucket from its bounds, row count, and distinct-value count.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Length of the value range, `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        let (lo, hi) = (self.lower_bound, self.upper_bound);
        hi - lo
    }

    /// Returns `true` when `value` lies in the half-open range `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Fraction (0.0 to 1.0) of this bucket's width covered by the range
    /// between `lower` and `upper`, where `None` leaves that side unbounded.
    ///
    /// A bucket whose width is zero or negative is treated as a point: it is
    /// either covered completely (1.0) or not at all (0.0).
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the requested range to the bucket's own bounds.
        let clipped_lower = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let clipped_upper = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));
        let span = self.width();

        if span <= 0.0 {
            let fully_covered =
                clipped_lower <= self.lower_bound && clipped_upper >= self.upper_bound;
            return if fully_covered { 1.0 } else { 0.0 };
        }

        let covered = (clipped_upper - clipped_lower).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```no_run
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower
213                && bucket_upper <= l
214            {
215                continue;
216            }
217            if let Some(u) = upper
218                && bucket_lower >= u
219            {
220                continue;
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
/// Counts distinct values in a sorted slice.
///
/// Because the input is sorted, equal values are adjacent. The count is
/// therefore one plus the number of positions where a value differs from its
/// predecessor. Values are compared exactly. An absolute tolerance such as
/// `f64::EPSILON` is wider than the spacing between representable values below
/// 1.0 and would merge genuinely distinct small-magnitude values (for example
/// `1e-20` and `2e-20`). `-0.0` and `0.0` compare equal and count once.
///
/// Returns 0 for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    let value_changes = sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();

    1 + value_changes as u64
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
/// Fallback selectivities used by cardinality estimation.
///
/// Each field is the fraction of rows assumed to survive a predicate of the
/// corresponding kind when neither a histogram nor column statistics can be
/// consulted. Tuning these values can improve plans for workloads whose skew
/// is known in advance.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Predicates with no more specific rule (default: 0.1).
    pub default: f64,
    /// Equality predicates without statistics (default: 0.01).
    pub equality: f64,
    /// Inequality predicates (default: 0.99).
    pub inequality: f64,
    /// Range predicates without statistics (default: 0.33).
    pub range: f64,
    /// String matching: STARTS WITH, ENDS WITH, CONTAINS, LIKE (default: 0.1).
    pub string_ops: f64,
    /// IN-list membership (default: 0.1).
    pub membership: f64,
    /// IS NULL (default: 0.05).
    pub is_null: f64,
    /// IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Share of input rows assumed to remain after DISTINCT (default: 0.5).
    pub distinct_fraction: f64,
}

impl Default for SelectivityConfig {
    /// The standard database defaults listed on each field.
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}

impl SelectivityConfig {
    /// Returns the standard defaults; identical to [`SelectivityConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
446
/// One recorded pairing of an optimizer estimate with the observed row count.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Human-readable label for the operator (e.g., "NodeScan(Person)").
    pub operator: String,
    /// Row count the optimizer predicted.
    pub estimated: f64,
    /// Row count observed at execution time.
    pub actual: f64,
}

impl EstimationEntry {
    /// Ratio `actual / estimated` describing how far off the estimate was.
    ///
    /// 1.0 means the estimate was exact, above 1.0 means it was too low and
    /// below 1.0 means it was too high. When the estimate is (within
    /// `f64::EPSILON` of) zero, the result is 1.0 if the actual count is also
    /// zero and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        let estimate_is_zero = self.estimated.abs() < f64::EPSILON;
        let actual_is_zero = self.actual.abs() < f64::EPSILON;

        match (estimate_is_zero, actual_is_zero) {
            (false, _) => self.actual / self.estimated,
            (true, true) => 1.0,
            (true, false) => f64::INFINITY,
        }
    }
}
477
478/// Collects estimate vs actual cardinality data for query plan analysis.
479///
480/// After executing a query, call [`record()`](Self::record) for each
481/// operator with its estimated and actual cardinalities. Then use
482/// [`should_replan()`](Self::should_replan) to decide whether the plan
483/// should be re-optimized.
484#[derive(Debug, Clone, Default)]
485pub struct EstimationLog {
486    /// Recorded entries.
487    entries: Vec<EstimationEntry>,
488    /// Error ratio threshold that triggers re-planning (default: 10.0).
489    ///
490    /// If any operator's error ratio exceeds this, `should_replan()` returns true.
491    replan_threshold: f64,
492}
493
494impl EstimationLog {
495    /// Creates a new estimation log with the given re-planning threshold.
496    #[must_use]
497    pub fn new(replan_threshold: f64) -> Self {
498        Self {
499            entries: Vec::new(),
500            replan_threshold,
501        }
502    }
503
504    /// Records an estimate-vs-actual observation.
505    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
506        self.entries.push(EstimationEntry {
507            operator: operator.into(),
508            estimated,
509            actual,
510        });
511    }
512
513    /// Returns all recorded entries.
514    #[must_use]
515    pub fn entries(&self) -> &[EstimationEntry] {
516        &self.entries
517    }
518
519    /// Returns whether any operator's estimation error exceeds the threshold,
520    /// indicating the plan should be re-optimized.
521    #[must_use]
522    pub fn should_replan(&self) -> bool {
523        self.entries.iter().any(|e| {
524            let ratio = e.error_ratio();
525            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
526        })
527    }
528
529    /// Returns the maximum error ratio across all entries.
530    #[must_use]
531    pub fn max_error_ratio(&self) -> f64 {
532        self.entries
533            .iter()
534            .map(|e| {
535                let r = e.error_ratio();
536                // Normalize so both over- and under-estimation are > 1.0
537                if r < 1.0 { 1.0 / r } else { r }
538            })
539            .fold(1.0_f64, f64::max)
540    }
541
542    /// Clears all entries.
543    pub fn clear(&mut self) {
544        self.entries.clear();
545    }
546}
547
548/// Cardinality estimator.
549pub struct CardinalityEstimator {
550    /// Statistics for each label/table.
551    table_stats: HashMap<String, TableStats>,
552    /// Default row count for unknown tables.
553    default_row_count: u64,
554    /// Default selectivity for unknown predicates.
555    default_selectivity: f64,
556    /// Average edge fanout (outgoing edges per node).
557    avg_fanout: f64,
558    /// Configurable selectivity defaults.
559    selectivity_config: SelectivityConfig,
560}
561
562impl CardinalityEstimator {
563    /// Creates a new cardinality estimator with default settings.
564    #[must_use]
565    pub fn new() -> Self {
566        let config = SelectivityConfig::new();
567        Self {
568            table_stats: HashMap::new(),
569            default_row_count: 1000,
570            default_selectivity: config.default,
571            avg_fanout: 10.0,
572            selectivity_config: config,
573        }
574    }
575
576    /// Creates a new cardinality estimator with custom selectivity configuration.
577    #[must_use]
578    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579        Self {
580            table_stats: HashMap::new(),
581            default_row_count: 1000,
582            default_selectivity: config.default,
583            avg_fanout: 10.0,
584            selectivity_config: config,
585        }
586    }
587
588    /// Returns the current selectivity configuration.
589    #[must_use]
590    pub fn selectivity_config(&self) -> &SelectivityConfig {
591        &self.selectivity_config
592    }
593
594    /// Creates an estimation log with the default re-planning threshold (10x).
595    #[must_use]
596    pub fn create_estimation_log() -> EstimationLog {
597        EstimationLog::new(10.0)
598    }
599
600    /// Creates a cardinality estimator pre-populated from store statistics.
601    ///
602    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
603    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
604    /// missing statistics.
605    #[must_use]
606    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607        let mut estimator = Self::new();
608
609        // Use total node count as default for unlabeled scans
610        if stats.total_nodes > 0 {
611            estimator.default_row_count = stats.total_nodes;
612        }
613
614        // Convert label statistics to optimizer table stats
615        for (label, label_stats) in &stats.labels {
616            let mut table_stats = TableStats::new(label_stats.node_count);
617
618            // Map property statistics (distinct count for selectivity estimation)
619            for (prop, col_stats) in &label_stats.properties {
620                let optimizer_col =
621                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622                table_stats = table_stats.with_column(prop, optimizer_col);
623            }
624
625            estimator.add_table_stats(label, table_stats);
626        }
627
628        // Compute average fanout from edge type statistics
629        if !stats.edge_types.is_empty() {
630            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632        } else if stats.total_nodes > 0 {
633            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634        }
635
636        // Clamp fanout to a reasonable minimum
637        if estimator.avg_fanout < 1.0 {
638            estimator.avg_fanout = 1.0;
639        }
640
641        estimator
642    }
643
644    /// Adds statistics for a table/label.
645    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646        self.table_stats.insert(name.to_string(), stats);
647    }
648
649    /// Sets the average edge fanout.
650    pub fn set_avg_fanout(&mut self, fanout: f64) {
651        self.avg_fanout = fanout;
652    }
653
654    /// Estimates the cardinality of a logical operator.
655    #[must_use]
656    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657        match op {
658            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660            LogicalOperator::Project(project) => self.estimate_project(project),
661            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662            LogicalOperator::Join(join) => self.estimate_join(join),
663            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668            LogicalOperator::Return(ret) => self.estimate(&ret.input),
669            LogicalOperator::Empty => 0.0,
670            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672            LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
673            _ => self.default_row_count as f64,
674        }
675    }
676
677    /// Estimates node scan cardinality.
678    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
679        if let Some(label) = &scan.label
680            && let Some(stats) = self.table_stats.get(label)
681        {
682            return stats.row_count as f64;
683        }
684        // No label filter - scan all nodes
685        self.default_row_count as f64
686    }
687
688    /// Estimates filter cardinality.
689    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
690        let input_cardinality = self.estimate(&filter.input);
691        let selectivity = self.estimate_selectivity(&filter.predicate);
692        (input_cardinality * selectivity).max(1.0)
693    }
694
695    /// Estimates projection cardinality (same as input).
696    fn estimate_project(&self, project: &ProjectOp) -> f64 {
697        self.estimate(&project.input)
698    }
699
700    /// Estimates expand cardinality.
701    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
702        let input_cardinality = self.estimate(&expand.input);
703
704        // Apply fanout based on edge type
705        let fanout = if !expand.edge_types.is_empty() {
706            // Specific edge type(s) typically have lower fanout
707            self.avg_fanout * 0.5
708        } else {
709            self.avg_fanout
710        };
711
712        // Handle variable-length paths
713        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
714            let min = expand.min_hops as f64;
715            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
716            // Geometric series approximation
717            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
718        } else {
719            fanout
720        };
721
722        (input_cardinality * path_multiplier).max(1.0)
723    }
724
725    /// Estimates join cardinality.
726    fn estimate_join(&self, join: &JoinOp) -> f64 {
727        let left_card = self.estimate(&join.left);
728        let right_card = self.estimate(&join.right);
729
730        match join.join_type {
731            JoinType::Cross => left_card * right_card,
732            JoinType::Inner => {
733                // Assume join selectivity based on conditions
734                let selectivity = if join.conditions.is_empty() {
735                    1.0 // Cross join
736                } else {
737                    // Estimate based on number of conditions
738                    0.1_f64.powi(join.conditions.len() as i32)
739                };
740                (left_card * right_card * selectivity).max(1.0)
741            }
742            JoinType::Left => {
743                // Left join returns at least all left rows
744                let inner_card = self.estimate_join(&JoinOp {
745                    left: join.left.clone(),
746                    right: join.right.clone(),
747                    join_type: JoinType::Inner,
748                    conditions: join.conditions.clone(),
749                });
750                inner_card.max(left_card)
751            }
752            JoinType::Right => {
753                // Right join returns at least all right rows
754                let inner_card = self.estimate_join(&JoinOp {
755                    left: join.left.clone(),
756                    right: join.right.clone(),
757                    join_type: JoinType::Inner,
758                    conditions: join.conditions.clone(),
759                });
760                inner_card.max(right_card)
761            }
762            JoinType::Full => {
763                // Full join returns at least max(left, right)
764                let inner_card = self.estimate_join(&JoinOp {
765                    left: join.left.clone(),
766                    right: join.right.clone(),
767                    join_type: JoinType::Inner,
768                    conditions: join.conditions.clone(),
769                });
770                inner_card.max(left_card.max(right_card))
771            }
772            JoinType::Semi => {
773                // Semi join returns at most left cardinality
774                (left_card * self.default_selectivity).max(1.0)
775            }
776            JoinType::Anti => {
777                // Anti join returns at most left cardinality
778                (left_card * (1.0 - self.default_selectivity)).max(1.0)
779            }
780        }
781    }
782
783    /// Estimates aggregation cardinality.
784    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
785        let input_cardinality = self.estimate(&agg.input);
786
787        if agg.group_by.is_empty() {
788            // Global aggregation - single row
789            1.0
790        } else {
791            // Group by - estimate distinct groups
792            // Assume each group key reduces cardinality by 10
793            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
794            (input_cardinality / group_reduction).max(1.0)
795        }
796    }
797
798    /// Estimates sort cardinality (same as input).
799    fn estimate_sort(&self, sort: &SortOp) -> f64 {
800        self.estimate(&sort.input)
801    }
802
803    /// Estimates distinct cardinality.
804    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
805        let input_cardinality = self.estimate(&distinct.input);
806        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
807    }
808
809    /// Estimates limit cardinality.
810    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
811        let input_cardinality = self.estimate(&limit.input);
812        (limit.count as f64).min(input_cardinality)
813    }
814
815    /// Estimates skip cardinality.
816    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
817        let input_cardinality = self.estimate(&skip.input);
818        (input_cardinality - skip.count as f64).max(0.0)
819    }
820
821    /// Estimates vector scan cardinality.
822    ///
823    /// Vector scan returns at most k results (the k nearest neighbors).
824    /// With similarity/distance filters, it may return fewer.
825    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
826        let base_k = scan.k as f64;
827
828        // Apply filter selectivity if thresholds are specified
829        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
830            // Assume 70% of results pass threshold filters
831            0.7
832        } else {
833            1.0
834        };
835
836        (base_k * selectivity).max(1.0)
837    }
838
839    /// Estimates vector join cardinality.
840    ///
841    /// Vector join produces up to k results per input row.
842    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
843        let input_cardinality = self.estimate(&join.input);
844        let k = join.k as f64;
845
846        // Apply filter selectivity if thresholds are specified
847        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
848            0.7
849        } else {
850            1.0
851        };
852
853        (input_cardinality * k * selectivity).max(1.0)
854    }
855
856    /// Estimates multi-way join cardinality using the AGM bound heuristic.
857    ///
858    /// For a cyclic join of N relations, the AGM (Atserias-Grohe-Marx) bound
859    /// gives min(cardinality)^(N/2) as a worst-case output size estimate.
860    fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
861        if mwj.inputs.is_empty() {
862            return 0.0;
863        }
864        let cardinalities: Vec<f64> = mwj
865            .inputs
866            .iter()
867            .map(|input| self.estimate(input))
868            .collect();
869        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
870        let n = cardinalities.len() as f64;
871        // AGM bound: min(cardinality)^(n/2)
872        (min_card.powf(n / 2.0)).max(1.0)
873    }
874
875    /// Estimates the selectivity of a predicate (0.0 to 1.0).
876    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
877        match expr {
878            LogicalExpression::Binary { left, op, right } => {
879                self.estimate_binary_selectivity(left, *op, right)
880            }
881            LogicalExpression::Unary { op, operand } => {
882                self.estimate_unary_selectivity(*op, operand)
883            }
884            LogicalExpression::Literal(value) => {
885                // Boolean literal
886                if let grafeo_common::types::Value::Bool(b) = value {
887                    if *b { 1.0 } else { 0.0 }
888                } else {
889                    self.default_selectivity
890                }
891            }
892            _ => self.default_selectivity,
893        }
894    }
895
896    /// Estimates binary expression selectivity.
897    fn estimate_binary_selectivity(
898        &self,
899        left: &LogicalExpression,
900        op: BinaryOp,
901        right: &LogicalExpression,
902    ) -> f64 {
903        match op {
904            // Equality - try histogram-based estimation
905            BinaryOp::Eq => {
906                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
907                    return selectivity;
908                }
909                self.selectivity_config.equality
910            }
911            // Inequality is very unselective
912            BinaryOp::Ne => self.selectivity_config.inequality,
913            // Range predicates - use histogram if available
914            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
915                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
916                    return selectivity;
917                }
918                self.selectivity_config.range
919            }
920            // Logical operators - recursively estimate sub-expressions
921            BinaryOp::And => {
922                let left_sel = self.estimate_selectivity(left);
923                let right_sel = self.estimate_selectivity(right);
924                // AND reduces selectivity (multiply assuming independence)
925                left_sel * right_sel
926            }
927            BinaryOp::Or => {
928                let left_sel = self.estimate_selectivity(left);
929                let right_sel = self.estimate_selectivity(right);
930                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
931                // Assuming independence: P(A ∩ B) = P(A) * P(B)
932                (left_sel + right_sel - left_sel * right_sel).min(1.0)
933            }
934            // String operations
935            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
936                self.selectivity_config.string_ops
937            }
938            // Collection membership
939            BinaryOp::In => self.selectivity_config.membership,
940            // Other operations
941            _ => self.default_selectivity,
942        }
943    }
944
945    /// Tries to estimate equality selectivity using histograms.
946    fn try_equality_selectivity(
947        &self,
948        left: &LogicalExpression,
949        right: &LogicalExpression,
950    ) -> Option<f64> {
951        // Extract property access and literal value
952        let (label, column, value) = self.extract_column_and_value(left, right)?;
953
954        // Get column stats with histogram
955        let stats = self.get_column_stats(&label, &column)?;
956
957        // Try histogram-based estimation
958        if let Some(ref histogram) = stats.histogram {
959            return Some(histogram.equality_selectivity(value));
960        }
961
962        // Fall back to distinct count estimation
963        if stats.distinct_count > 0 {
964            return Some(1.0 / stats.distinct_count as f64);
965        }
966
967        None
968    }
969
970    /// Tries to estimate range selectivity using histograms.
971    fn try_range_selectivity(
972        &self,
973        left: &LogicalExpression,
974        op: BinaryOp,
975        right: &LogicalExpression,
976    ) -> Option<f64> {
977        // Extract property access and literal value
978        let (label, column, value) = self.extract_column_and_value(left, right)?;
979
980        // Get column stats
981        let stats = self.get_column_stats(&label, &column)?;
982
983        // Determine the range based on operator
984        let (lower, upper) = match op {
985            BinaryOp::Lt => (None, Some(value)),
986            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
987            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
988            BinaryOp::Ge => (Some(value), None),
989            _ => return None,
990        };
991
992        // Try histogram-based estimation first
993        if let Some(ref histogram) = stats.histogram {
994            return Some(histogram.range_selectivity(lower, upper));
995        }
996
997        // Fall back to min/max range estimation
998        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
999            let range = max - min;
1000            if range <= 0.0 {
1001                return Some(1.0);
1002            }
1003
1004            let effective_lower = lower.unwrap_or(min).max(min);
1005            let effective_upper = upper.unwrap_or(max).min(max);
1006            let overlap = (effective_upper - effective_lower).max(0.0);
1007            return Some((overlap / range).clamp(0.0, 1.0));
1008        }
1009
1010        None
1011    }
1012
1013    /// Extracts column information and literal value from a comparison.
1014    ///
1015    /// Returns (label, column_name, numeric_value) if the expression is
1016    /// a comparison between a property access and a numeric literal.
1017    fn extract_column_and_value(
1018        &self,
1019        left: &LogicalExpression,
1020        right: &LogicalExpression,
1021    ) -> Option<(String, String, f64)> {
1022        // Try left as property, right as literal
1023        if let Some(result) = self.try_extract_property_literal(left, right) {
1024            return Some(result);
1025        }
1026
1027        // Try right as property, left as literal
1028        self.try_extract_property_literal(right, left)
1029    }
1030
1031    /// Tries to extract property and literal from a specific ordering.
1032    fn try_extract_property_literal(
1033        &self,
1034        property_expr: &LogicalExpression,
1035        literal_expr: &LogicalExpression,
1036    ) -> Option<(String, String, f64)> {
1037        // Extract property access
1038        let (variable, property) = match property_expr {
1039            LogicalExpression::Property { variable, property } => {
1040                (variable.clone(), property.clone())
1041            }
1042            _ => return None,
1043        };
1044
1045        // Extract numeric literal
1046        let value = match literal_expr {
1047            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1048            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1049            _ => return None,
1050        };
1051
1052        // Try to find a label for this variable from table stats
1053        // Use the variable name as a heuristic label lookup
1054        // In practice, the optimizer would track which labels variables are bound to
1055        for label in self.table_stats.keys() {
1056            if let Some(stats) = self.table_stats.get(label)
1057                && stats.columns.contains_key(&property)
1058            {
1059                return Some((label.clone(), property, value));
1060            }
1061        }
1062
1063        // If no stats found but we have the property, return with variable as label
1064        Some((variable, property, value))
1065    }
1066
1067    /// Estimates unary expression selectivity.
1068    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1069        match op {
1070            UnaryOp::Not => 1.0 - self.default_selectivity,
1071            UnaryOp::IsNull => self.selectivity_config.is_null,
1072            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1073            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1074        }
1075    }
1076
1077    /// Gets statistics for a column.
1078    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1079        self.table_stats.get(label)?.columns.get(column)
1080    }
1081}
1082
1083impl Default for CardinalityEstimator {
1084    fn default() -> Self {
1085        Self::new()
1086    }
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091    use super::*;
1092    use crate::query::plan::{
1093        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1094        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1095    };
1096    use grafeo_common::types::Value;
1097
1098    #[test]
1099    fn test_node_scan_with_stats() {
1100        let mut estimator = CardinalityEstimator::new();
1101        estimator.add_table_stats("Person", TableStats::new(5000));
1102
1103        let scan = LogicalOperator::NodeScan(NodeScanOp {
1104            variable: "n".to_string(),
1105            label: Some("Person".to_string()),
1106            input: None,
1107        });
1108
1109        let cardinality = estimator.estimate(&scan);
1110        assert!((cardinality - 5000.0).abs() < 0.001);
1111    }
1112
1113    #[test]
1114    fn test_filter_reduces_cardinality() {
1115        let mut estimator = CardinalityEstimator::new();
1116        estimator.add_table_stats("Person", TableStats::new(1000));
1117
1118        let filter = LogicalOperator::Filter(FilterOp {
1119            predicate: LogicalExpression::Binary {
1120                left: Box::new(LogicalExpression::Property {
1121                    variable: "n".to_string(),
1122                    property: "age".to_string(),
1123                }),
1124                op: BinaryOp::Eq,
1125                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1126            },
1127            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1128                variable: "n".to_string(),
1129                label: Some("Person".to_string()),
1130                input: None,
1131            })),
1132            pushdown_hint: None,
1133        });
1134
1135        let cardinality = estimator.estimate(&filter);
1136        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1137        assert!(cardinality < 1000.0);
1138        assert!(cardinality >= 1.0);
1139    }
1140
1141    #[test]
1142    fn test_join_cardinality() {
1143        let mut estimator = CardinalityEstimator::new();
1144        estimator.add_table_stats("Person", TableStats::new(1000));
1145        estimator.add_table_stats("Company", TableStats::new(100));
1146
1147        let join = LogicalOperator::Join(JoinOp {
1148            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1149                variable: "p".to_string(),
1150                label: Some("Person".to_string()),
1151                input: None,
1152            })),
1153            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1154                variable: "c".to_string(),
1155                label: Some("Company".to_string()),
1156                input: None,
1157            })),
1158            join_type: JoinType::Inner,
1159            conditions: vec![JoinCondition {
1160                left: LogicalExpression::Property {
1161                    variable: "p".to_string(),
1162                    property: "company_id".to_string(),
1163                },
1164                right: LogicalExpression::Property {
1165                    variable: "c".to_string(),
1166                    property: "id".to_string(),
1167                },
1168            }],
1169        });
1170
1171        let cardinality = estimator.estimate(&join);
1172        // Should be less than cross product
1173        assert!(cardinality < 1000.0 * 100.0);
1174    }
1175
1176    #[test]
1177    fn test_limit_caps_cardinality() {
1178        let mut estimator = CardinalityEstimator::new();
1179        estimator.add_table_stats("Person", TableStats::new(1000));
1180
1181        let limit = LogicalOperator::Limit(LimitOp {
1182            count: 10,
1183            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1184                variable: "n".to_string(),
1185                label: Some("Person".to_string()),
1186                input: None,
1187            })),
1188        });
1189
1190        let cardinality = estimator.estimate(&limit);
1191        assert!((cardinality - 10.0).abs() < 0.001);
1192    }
1193
1194    #[test]
1195    fn test_aggregate_reduces_cardinality() {
1196        let mut estimator = CardinalityEstimator::new();
1197        estimator.add_table_stats("Person", TableStats::new(1000));
1198
1199        // Global aggregation
1200        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1201            group_by: vec![],
1202            aggregates: vec![],
1203            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1204                variable: "n".to_string(),
1205                label: Some("Person".to_string()),
1206                input: None,
1207            })),
1208            having: None,
1209        });
1210
1211        let cardinality = estimator.estimate(&global_agg);
1212        assert!((cardinality - 1.0).abs() < 0.001);
1213
1214        // Group by aggregation
1215        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1216            group_by: vec![LogicalExpression::Property {
1217                variable: "n".to_string(),
1218                property: "city".to_string(),
1219            }],
1220            aggregates: vec![],
1221            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1222                variable: "n".to_string(),
1223                label: Some("Person".to_string()),
1224                input: None,
1225            })),
1226            having: None,
1227        });
1228
1229        let cardinality = estimator.estimate(&group_agg);
1230        // Should be less than input
1231        assert!(cardinality < 1000.0);
1232    }
1233
1234    #[test]
1235    fn test_node_scan_without_stats() {
1236        let estimator = CardinalityEstimator::new();
1237
1238        let scan = LogicalOperator::NodeScan(NodeScanOp {
1239            variable: "n".to_string(),
1240            label: Some("Unknown".to_string()),
1241            input: None,
1242        });
1243
1244        let cardinality = estimator.estimate(&scan);
1245        // Should return default (1000)
1246        assert!((cardinality - 1000.0).abs() < 0.001);
1247    }
1248
1249    #[test]
1250    fn test_node_scan_no_label() {
1251        let estimator = CardinalityEstimator::new();
1252
1253        let scan = LogicalOperator::NodeScan(NodeScanOp {
1254            variable: "n".to_string(),
1255            label: None,
1256            input: None,
1257        });
1258
1259        let cardinality = estimator.estimate(&scan);
1260        // Should scan all nodes (default)
1261        assert!((cardinality - 1000.0).abs() < 0.001);
1262    }
1263
1264    #[test]
1265    fn test_filter_inequality_selectivity() {
1266        let mut estimator = CardinalityEstimator::new();
1267        estimator.add_table_stats("Person", TableStats::new(1000));
1268
1269        let filter = LogicalOperator::Filter(FilterOp {
1270            predicate: LogicalExpression::Binary {
1271                left: Box::new(LogicalExpression::Property {
1272                    variable: "n".to_string(),
1273                    property: "age".to_string(),
1274                }),
1275                op: BinaryOp::Ne,
1276                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1277            },
1278            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1279                variable: "n".to_string(),
1280                label: Some("Person".to_string()),
1281                input: None,
1282            })),
1283            pushdown_hint: None,
1284        });
1285
1286        let cardinality = estimator.estimate(&filter);
1287        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1288        assert!(cardinality > 900.0);
1289    }
1290
1291    #[test]
1292    fn test_filter_range_selectivity() {
1293        let mut estimator = CardinalityEstimator::new();
1294        estimator.add_table_stats("Person", TableStats::new(1000));
1295
1296        let filter = LogicalOperator::Filter(FilterOp {
1297            predicate: LogicalExpression::Binary {
1298                left: Box::new(LogicalExpression::Property {
1299                    variable: "n".to_string(),
1300                    property: "age".to_string(),
1301                }),
1302                op: BinaryOp::Gt,
1303                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1304            },
1305            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1306                variable: "n".to_string(),
1307                label: Some("Person".to_string()),
1308                input: None,
1309            })),
1310            pushdown_hint: None,
1311        });
1312
1313        let cardinality = estimator.estimate(&filter);
1314        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1315        assert!(cardinality < 500.0);
1316        assert!(cardinality > 100.0);
1317    }
1318
1319    #[test]
1320    fn test_filter_and_selectivity() {
1321        let mut estimator = CardinalityEstimator::new();
1322        estimator.add_table_stats("Person", TableStats::new(1000));
1323
1324        // Test AND with two equality predicates
1325        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1326        let filter = LogicalOperator::Filter(FilterOp {
1327            predicate: LogicalExpression::Binary {
1328                left: Box::new(LogicalExpression::Binary {
1329                    left: Box::new(LogicalExpression::Property {
1330                        variable: "n".to_string(),
1331                        property: "city".to_string(),
1332                    }),
1333                    op: BinaryOp::Eq,
1334                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1335                }),
1336                op: BinaryOp::And,
1337                right: Box::new(LogicalExpression::Binary {
1338                    left: Box::new(LogicalExpression::Property {
1339                        variable: "n".to_string(),
1340                        property: "age".to_string(),
1341                    }),
1342                    op: BinaryOp::Eq,
1343                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1344                }),
1345            },
1346            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1347                variable: "n".to_string(),
1348                label: Some("Person".to_string()),
1349                input: None,
1350            })),
1351            pushdown_hint: None,
1352        });
1353
1354        let cardinality = estimator.estimate(&filter);
1355        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1356        // 1000 * 0.0001 = 0.1, min is 1.0
1357        assert!(cardinality < 100.0);
1358        assert!(cardinality >= 1.0);
1359    }
1360
1361    #[test]
1362    fn test_filter_or_selectivity() {
1363        let mut estimator = CardinalityEstimator::new();
1364        estimator.add_table_stats("Person", TableStats::new(1000));
1365
1366        // Test OR with two equality predicates
1367        // Each equality has selectivity 0.01
1368        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1369        let filter = LogicalOperator::Filter(FilterOp {
1370            predicate: LogicalExpression::Binary {
1371                left: Box::new(LogicalExpression::Binary {
1372                    left: Box::new(LogicalExpression::Property {
1373                        variable: "n".to_string(),
1374                        property: "city".to_string(),
1375                    }),
1376                    op: BinaryOp::Eq,
1377                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1378                }),
1379                op: BinaryOp::Or,
1380                right: Box::new(LogicalExpression::Binary {
1381                    left: Box::new(LogicalExpression::Property {
1382                        variable: "n".to_string(),
1383                        property: "city".to_string(),
1384                    }),
1385                    op: BinaryOp::Eq,
1386                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1387                }),
1388            },
1389            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1390                variable: "n".to_string(),
1391                label: Some("Person".to_string()),
1392                input: None,
1393            })),
1394            pushdown_hint: None,
1395        });
1396
1397        let cardinality = estimator.estimate(&filter);
1398        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1399        assert!(cardinality < 100.0);
1400        assert!(cardinality >= 1.0);
1401    }
1402
1403    #[test]
1404    fn test_filter_literal_true() {
1405        let mut estimator = CardinalityEstimator::new();
1406        estimator.add_table_stats("Person", TableStats::new(1000));
1407
1408        let filter = LogicalOperator::Filter(FilterOp {
1409            predicate: LogicalExpression::Literal(Value::Bool(true)),
1410            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1411                variable: "n".to_string(),
1412                label: Some("Person".to_string()),
1413                input: None,
1414            })),
1415            pushdown_hint: None,
1416        });
1417
1418        let cardinality = estimator.estimate(&filter);
1419        // Literal true has selectivity 1.0
1420        assert!((cardinality - 1000.0).abs() < 0.001);
1421    }
1422
1423    #[test]
1424    fn test_filter_literal_false() {
1425        let mut estimator = CardinalityEstimator::new();
1426        estimator.add_table_stats("Person", TableStats::new(1000));
1427
1428        let filter = LogicalOperator::Filter(FilterOp {
1429            predicate: LogicalExpression::Literal(Value::Bool(false)),
1430            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1431                variable: "n".to_string(),
1432                label: Some("Person".to_string()),
1433                input: None,
1434            })),
1435            pushdown_hint: None,
1436        });
1437
1438        let cardinality = estimator.estimate(&filter);
1439        // Literal false has selectivity 0.0, but min is 1.0
1440        assert!((cardinality - 1.0).abs() < 0.001);
1441    }
1442
1443    #[test]
1444    fn test_unary_not_selectivity() {
1445        let mut estimator = CardinalityEstimator::new();
1446        estimator.add_table_stats("Person", TableStats::new(1000));
1447
1448        let filter = LogicalOperator::Filter(FilterOp {
1449            predicate: LogicalExpression::Unary {
1450                op: UnaryOp::Not,
1451                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1452            },
1453            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1454                variable: "n".to_string(),
1455                label: Some("Person".to_string()),
1456                input: None,
1457            })),
1458            pushdown_hint: None,
1459        });
1460
1461        let cardinality = estimator.estimate(&filter);
1462        // NOT inverts selectivity
1463        assert!(cardinality < 1000.0);
1464    }
1465
1466    #[test]
1467    fn test_unary_is_null_selectivity() {
1468        let mut estimator = CardinalityEstimator::new();
1469        estimator.add_table_stats("Person", TableStats::new(1000));
1470
1471        let filter = LogicalOperator::Filter(FilterOp {
1472            predicate: LogicalExpression::Unary {
1473                op: UnaryOp::IsNull,
1474                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1475            },
1476            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1477                variable: "n".to_string(),
1478                label: Some("Person".to_string()),
1479                input: None,
1480            })),
1481            pushdown_hint: None,
1482        });
1483
1484        let cardinality = estimator.estimate(&filter);
1485        // IS NULL has selectivity 0.05
1486        assert!(cardinality < 100.0);
1487    }
1488
1489    #[test]
1490    fn test_expand_cardinality() {
1491        let mut estimator = CardinalityEstimator::new();
1492        estimator.add_table_stats("Person", TableStats::new(100));
1493
1494        let expand = LogicalOperator::Expand(ExpandOp {
1495            from_variable: "a".to_string(),
1496            to_variable: "b".to_string(),
1497            edge_variable: None,
1498            direction: ExpandDirection::Outgoing,
1499            edge_types: vec![],
1500            min_hops: 1,
1501            max_hops: Some(1),
1502            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1503                variable: "a".to_string(),
1504                label: Some("Person".to_string()),
1505                input: None,
1506            })),
1507            path_alias: None,
1508            path_mode: PathMode::Walk,
1509        });
1510
1511        let cardinality = estimator.estimate(&expand);
1512        // Expand multiplies by fanout (10)
1513        assert!(cardinality > 100.0);
1514    }
1515
1516    #[test]
1517    fn test_expand_with_edge_type_filter() {
1518        let mut estimator = CardinalityEstimator::new();
1519        estimator.add_table_stats("Person", TableStats::new(100));
1520
1521        let expand = LogicalOperator::Expand(ExpandOp {
1522            from_variable: "a".to_string(),
1523            to_variable: "b".to_string(),
1524            edge_variable: None,
1525            direction: ExpandDirection::Outgoing,
1526            edge_types: vec!["KNOWS".to_string()],
1527            min_hops: 1,
1528            max_hops: Some(1),
1529            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1530                variable: "a".to_string(),
1531                label: Some("Person".to_string()),
1532                input: None,
1533            })),
1534            path_alias: None,
1535            path_mode: PathMode::Walk,
1536        });
1537
1538        let cardinality = estimator.estimate(&expand);
1539        // With edge type, fanout is reduced by half
1540        assert!(cardinality > 100.0);
1541    }
1542
1543    #[test]
1544    fn test_expand_variable_length() {
1545        let mut estimator = CardinalityEstimator::new();
1546        estimator.add_table_stats("Person", TableStats::new(100));
1547
1548        let expand = LogicalOperator::Expand(ExpandOp {
1549            from_variable: "a".to_string(),
1550            to_variable: "b".to_string(),
1551            edge_variable: None,
1552            direction: ExpandDirection::Outgoing,
1553            edge_types: vec![],
1554            min_hops: 1,
1555            max_hops: Some(3),
1556            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557                variable: "a".to_string(),
1558                label: Some("Person".to_string()),
1559                input: None,
1560            })),
1561            path_alias: None,
1562            path_mode: PathMode::Walk,
1563        });
1564
1565        let cardinality = estimator.estimate(&expand);
1566        // Variable length path has much higher cardinality
1567        assert!(cardinality > 500.0);
1568    }
1569
1570    #[test]
1571    fn test_join_cross_product() {
1572        let mut estimator = CardinalityEstimator::new();
1573        estimator.add_table_stats("Person", TableStats::new(100));
1574        estimator.add_table_stats("Company", TableStats::new(50));
1575
1576        let join = LogicalOperator::Join(JoinOp {
1577            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1578                variable: "p".to_string(),
1579                label: Some("Person".to_string()),
1580                input: None,
1581            })),
1582            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583                variable: "c".to_string(),
1584                label: Some("Company".to_string()),
1585                input: None,
1586            })),
1587            join_type: JoinType::Cross,
1588            conditions: vec![],
1589        });
1590
1591        let cardinality = estimator.estimate(&join);
1592        // Cross join = 100 * 50 = 5000
1593        assert!((cardinality - 5000.0).abs() < 0.001);
1594    }
1595
1596    #[test]
1597    fn test_join_left_outer() {
1598        let mut estimator = CardinalityEstimator::new();
1599        estimator.add_table_stats("Person", TableStats::new(1000));
1600        estimator.add_table_stats("Company", TableStats::new(10));
1601
1602        let join = LogicalOperator::Join(JoinOp {
1603            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1604                variable: "p".to_string(),
1605                label: Some("Person".to_string()),
1606                input: None,
1607            })),
1608            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1609                variable: "c".to_string(),
1610                label: Some("Company".to_string()),
1611                input: None,
1612            })),
1613            join_type: JoinType::Left,
1614            conditions: vec![JoinCondition {
1615                left: LogicalExpression::Variable("p".to_string()),
1616                right: LogicalExpression::Variable("c".to_string()),
1617            }],
1618        });
1619
1620        let cardinality = estimator.estimate(&join);
1621        // Left join returns at least all left rows
1622        assert!(cardinality >= 1000.0);
1623    }
1624
1625    #[test]
1626    fn test_join_semi() {
1627        let mut estimator = CardinalityEstimator::new();
1628        estimator.add_table_stats("Person", TableStats::new(1000));
1629        estimator.add_table_stats("Company", TableStats::new(100));
1630
1631        let join = LogicalOperator::Join(JoinOp {
1632            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1633                variable: "p".to_string(),
1634                label: Some("Person".to_string()),
1635                input: None,
1636            })),
1637            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1638                variable: "c".to_string(),
1639                label: Some("Company".to_string()),
1640                input: None,
1641            })),
1642            join_type: JoinType::Semi,
1643            conditions: vec![],
1644        });
1645
1646        let cardinality = estimator.estimate(&join);
1647        // Semi join returns at most left cardinality
1648        assert!(cardinality <= 1000.0);
1649    }
1650
1651    #[test]
1652    fn test_join_anti() {
1653        let mut estimator = CardinalityEstimator::new();
1654        estimator.add_table_stats("Person", TableStats::new(1000));
1655        estimator.add_table_stats("Company", TableStats::new(100));
1656
1657        let join = LogicalOperator::Join(JoinOp {
1658            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1659                variable: "p".to_string(),
1660                label: Some("Person".to_string()),
1661                input: None,
1662            })),
1663            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664                variable: "c".to_string(),
1665                label: Some("Company".to_string()),
1666                input: None,
1667            })),
1668            join_type: JoinType::Anti,
1669            conditions: vec![],
1670        });
1671
1672        let cardinality = estimator.estimate(&join);
1673        // Anti join returns at most left cardinality
1674        assert!(cardinality <= 1000.0);
1675        assert!(cardinality >= 1.0);
1676    }
1677
1678    #[test]
1679    fn test_project_preserves_cardinality() {
1680        let mut estimator = CardinalityEstimator::new();
1681        estimator.add_table_stats("Person", TableStats::new(1000));
1682
1683        let project = LogicalOperator::Project(ProjectOp {
1684            projections: vec![Projection {
1685                expression: LogicalExpression::Variable("n".to_string()),
1686                alias: None,
1687            }],
1688            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1689                variable: "n".to_string(),
1690                label: Some("Person".to_string()),
1691                input: None,
1692            })),
1693        });
1694
1695        let cardinality = estimator.estimate(&project);
1696        assert!((cardinality - 1000.0).abs() < 0.001);
1697    }
1698
1699    #[test]
1700    fn test_sort_preserves_cardinality() {
1701        let mut estimator = CardinalityEstimator::new();
1702        estimator.add_table_stats("Person", TableStats::new(1000));
1703
1704        let sort = LogicalOperator::Sort(SortOp {
1705            keys: vec![SortKey {
1706                expression: LogicalExpression::Variable("n".to_string()),
1707                order: SortOrder::Ascending,
1708                nulls: None,
1709            }],
1710            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1711                variable: "n".to_string(),
1712                label: Some("Person".to_string()),
1713                input: None,
1714            })),
1715        });
1716
1717        let cardinality = estimator.estimate(&sort);
1718        assert!((cardinality - 1000.0).abs() < 0.001);
1719    }
1720
1721    #[test]
1722    fn test_distinct_reduces_cardinality() {
1723        let mut estimator = CardinalityEstimator::new();
1724        estimator.add_table_stats("Person", TableStats::new(1000));
1725
1726        let distinct = LogicalOperator::Distinct(DistinctOp {
1727            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1728                variable: "n".to_string(),
1729                label: Some("Person".to_string()),
1730                input: None,
1731            })),
1732            columns: None,
1733        });
1734
1735        let cardinality = estimator.estimate(&distinct);
1736        // Distinct assumes 50% unique
1737        assert!((cardinality - 500.0).abs() < 0.001);
1738    }
1739
1740    #[test]
1741    fn test_skip_reduces_cardinality() {
1742        let mut estimator = CardinalityEstimator::new();
1743        estimator.add_table_stats("Person", TableStats::new(1000));
1744
1745        let skip = LogicalOperator::Skip(SkipOp {
1746            count: 100,
1747            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1748                variable: "n".to_string(),
1749                label: Some("Person".to_string()),
1750                input: None,
1751            })),
1752        });
1753
1754        let cardinality = estimator.estimate(&skip);
1755        assert!((cardinality - 900.0).abs() < 0.001);
1756    }
1757
1758    #[test]
1759    fn test_return_preserves_cardinality() {
1760        let mut estimator = CardinalityEstimator::new();
1761        estimator.add_table_stats("Person", TableStats::new(1000));
1762
1763        let ret = LogicalOperator::Return(ReturnOp {
1764            items: vec![ReturnItem {
1765                expression: LogicalExpression::Variable("n".to_string()),
1766                alias: None,
1767            }],
1768            distinct: false,
1769            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1770                variable: "n".to_string(),
1771                label: Some("Person".to_string()),
1772                input: None,
1773            })),
1774        });
1775
1776        let cardinality = estimator.estimate(&ret);
1777        assert!((cardinality - 1000.0).abs() < 0.001);
1778    }
1779
1780    #[test]
1781    fn test_empty_cardinality() {
1782        let estimator = CardinalityEstimator::new();
1783        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1784        assert!((cardinality).abs() < 0.001);
1785    }
1786
1787    #[test]
1788    fn test_table_stats_with_column() {
1789        let stats = TableStats::new(1000).with_column(
1790            "age",
1791            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1792        );
1793
1794        assert_eq!(stats.row_count, 1000);
1795        let col = stats.columns.get("age").unwrap();
1796        assert_eq!(col.distinct_count, 50);
1797        assert_eq!(col.null_count, 10);
1798        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1799        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1800    }
1801
1802    #[test]
1803    fn test_estimator_default() {
1804        let estimator = CardinalityEstimator::default();
1805        let scan = LogicalOperator::NodeScan(NodeScanOp {
1806            variable: "n".to_string(),
1807            label: None,
1808            input: None,
1809        });
1810        let cardinality = estimator.estimate(&scan);
1811        assert!((cardinality - 1000.0).abs() < 0.001);
1812    }
1813
1814    #[test]
1815    fn test_set_avg_fanout() {
1816        let mut estimator = CardinalityEstimator::new();
1817        estimator.add_table_stats("Person", TableStats::new(100));
1818        estimator.set_avg_fanout(5.0);
1819
1820        let expand = LogicalOperator::Expand(ExpandOp {
1821            from_variable: "a".to_string(),
1822            to_variable: "b".to_string(),
1823            edge_variable: None,
1824            direction: ExpandDirection::Outgoing,
1825            edge_types: vec![],
1826            min_hops: 1,
1827            max_hops: Some(1),
1828            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1829                variable: "a".to_string(),
1830                label: Some("Person".to_string()),
1831                input: None,
1832            })),
1833            path_alias: None,
1834            path_mode: PathMode::Walk,
1835        });
1836
1837        let cardinality = estimator.estimate(&expand);
1838        // With fanout 5: 100 * 5 = 500
1839        assert!((cardinality - 500.0).abs() < 0.001);
1840    }
1841
1842    #[test]
1843    fn test_multiple_group_by_keys_reduce_cardinality() {
1844        // The current implementation uses a simplified model where more group by keys
1845        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1846        // that works for most cases where group by keys are correlated.
1847        let mut estimator = CardinalityEstimator::new();
1848        estimator.add_table_stats("Person", TableStats::new(10000));
1849
1850        let single_group = LogicalOperator::Aggregate(AggregateOp {
1851            group_by: vec![LogicalExpression::Property {
1852                variable: "n".to_string(),
1853                property: "city".to_string(),
1854            }],
1855            aggregates: vec![],
1856            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1857                variable: "n".to_string(),
1858                label: Some("Person".to_string()),
1859                input: None,
1860            })),
1861            having: None,
1862        });
1863
1864        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1865            group_by: vec![
1866                LogicalExpression::Property {
1867                    variable: "n".to_string(),
1868                    property: "city".to_string(),
1869                },
1870                LogicalExpression::Property {
1871                    variable: "n".to_string(),
1872                    property: "country".to_string(),
1873                },
1874            ],
1875            aggregates: vec![],
1876            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1877                variable: "n".to_string(),
1878                label: Some("Person".to_string()),
1879                input: None,
1880            })),
1881            having: None,
1882        });
1883
1884        let single_card = estimator.estimate(&single_group);
1885        let multi_card = estimator.estimate(&multi_group);
1886
1887        // Both should reduce cardinality from input
1888        assert!(single_card < 10000.0);
1889        assert!(multi_card < 10000.0);
1890        // Both should be at least 1
1891        assert!(single_card >= 1.0);
1892        assert!(multi_card >= 1.0);
1893    }
1894
1895    // ============= Histogram Tests =============
1896
1897    #[test]
1898    fn test_histogram_build_uniform() {
1899        // Build histogram from uniformly distributed data
1900        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1901        let histogram = EquiDepthHistogram::build(&values, 10);
1902
1903        assert_eq!(histogram.num_buckets(), 10);
1904        assert_eq!(histogram.total_rows(), 100);
1905
1906        // Each bucket should have approximately 10 rows
1907        for bucket in histogram.buckets() {
1908            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1909        }
1910    }
1911
1912    #[test]
1913    fn test_histogram_build_skewed() {
1914        // Build histogram from skewed data (many small values, few large)
1915        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1916        values.extend((0..20).map(|i| 1000.0 + i as f64));
1917        let histogram = EquiDepthHistogram::build(&values, 5);
1918
1919        assert_eq!(histogram.num_buckets(), 5);
1920        assert_eq!(histogram.total_rows(), 100);
1921
1922        // Each bucket should have ~20 rows despite skewed data
1923        for bucket in histogram.buckets() {
1924            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1925        }
1926    }
1927
1928    #[test]
1929    fn test_histogram_range_selectivity_full() {
1930        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1931        let histogram = EquiDepthHistogram::build(&values, 10);
1932
1933        // Full range should have selectivity ~1.0
1934        let selectivity = histogram.range_selectivity(None, None);
1935        assert!((selectivity - 1.0).abs() < 0.01);
1936    }
1937
1938    #[test]
1939    fn test_histogram_range_selectivity_half() {
1940        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1941        let histogram = EquiDepthHistogram::build(&values, 10);
1942
1943        // Values >= 50 should be ~50% (half the data)
1944        let selectivity = histogram.range_selectivity(Some(50.0), None);
1945        assert!(selectivity > 0.4 && selectivity < 0.6);
1946    }
1947
1948    #[test]
1949    fn test_histogram_range_selectivity_quarter() {
1950        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1951        let histogram = EquiDepthHistogram::build(&values, 10);
1952
1953        // Values < 25 should be ~25%
1954        let selectivity = histogram.range_selectivity(None, Some(25.0));
1955        assert!(selectivity > 0.2 && selectivity < 0.3);
1956    }
1957
1958    #[test]
1959    fn test_histogram_equality_selectivity() {
1960        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1961        let histogram = EquiDepthHistogram::build(&values, 10);
1962
1963        // Equality on 100 distinct values should be ~1%
1964        let selectivity = histogram.equality_selectivity(50.0);
1965        assert!(selectivity > 0.005 && selectivity < 0.02);
1966    }
1967
1968    #[test]
1969    fn test_histogram_empty() {
1970        let histogram = EquiDepthHistogram::build(&[], 10);
1971
1972        assert_eq!(histogram.num_buckets(), 0);
1973        assert_eq!(histogram.total_rows(), 0);
1974
1975        // Default selectivity for empty histogram
1976        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1977        assert!((selectivity - 0.33).abs() < 0.01);
1978    }
1979
1980    #[test]
1981    fn test_histogram_bucket_overlap() {
1982        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1983
1984        // Full overlap
1985        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1986
1987        // Half overlap (lower half)
1988        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1989
1990        // Half overlap (upper half)
1991        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1992
1993        // No overlap (below)
1994        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1995
1996        // No overlap (above)
1997        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1998    }
1999
2000    #[test]
2001    fn test_column_stats_from_values() {
2002        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2003        let stats = ColumnStats::from_values(values, 4);
2004
2005        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
2006        assert!(stats.min_value.is_some());
2007        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2008        assert!(stats.max_value.is_some());
2009        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2010        assert!(stats.histogram.is_some());
2011    }
2012
2013    #[test]
2014    fn test_column_stats_with_histogram_builder() {
2015        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2016        let histogram = EquiDepthHistogram::build(&values, 10);
2017
2018        let stats = ColumnStats::new(100)
2019            .with_range(0.0, 99.0)
2020            .with_histogram(histogram);
2021
2022        assert!(stats.histogram.is_some());
2023        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2024    }
2025
2026    #[test]
2027    fn test_filter_with_histogram_stats() {
2028        let mut estimator = CardinalityEstimator::new();
2029
2030        // Create stats with histogram for age column
2031        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2032        let histogram = EquiDepthHistogram::build(&age_values, 10);
2033        let age_stats = ColumnStats::new(62)
2034            .with_range(18.0, 79.0)
2035            .with_histogram(histogram);
2036
2037        estimator.add_table_stats(
2038            "Person",
2039            TableStats::new(1000).with_column("age", age_stats),
2040        );
2041
2042        // Filter: age > 50
2043        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
2044        let filter = LogicalOperator::Filter(FilterOp {
2045            predicate: LogicalExpression::Binary {
2046                left: Box::new(LogicalExpression::Property {
2047                    variable: "n".to_string(),
2048                    property: "age".to_string(),
2049                }),
2050                op: BinaryOp::Gt,
2051                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2052            },
2053            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2054                variable: "n".to_string(),
2055                label: Some("Person".to_string()),
2056                input: None,
2057            })),
2058            pushdown_hint: None,
2059        });
2060
2061        let cardinality = estimator.estimate(&filter);
2062
2063        // With histogram, should get more accurate estimate than default 0.33
2064        // Expected: ~47.5% of 1000 = ~475
2065        assert!(cardinality > 300.0 && cardinality < 600.0);
2066    }
2067
2068    #[test]
2069    fn test_filter_equality_with_histogram() {
2070        let mut estimator = CardinalityEstimator::new();
2071
2072        // Create stats with histogram
2073        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2074        let histogram = EquiDepthHistogram::build(&values, 10);
2075        let stats = ColumnStats::new(100)
2076            .with_range(0.0, 99.0)
2077            .with_histogram(histogram);
2078
2079        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2080
2081        // Filter: value = 50
2082        let filter = LogicalOperator::Filter(FilterOp {
2083            predicate: LogicalExpression::Binary {
2084                left: Box::new(LogicalExpression::Property {
2085                    variable: "d".to_string(),
2086                    property: "value".to_string(),
2087                }),
2088                op: BinaryOp::Eq,
2089                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2090            },
2091            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2092                variable: "d".to_string(),
2093                label: Some("Data".to_string()),
2094                input: None,
2095            })),
2096            pushdown_hint: None,
2097        });
2098
2099        let cardinality = estimator.estimate(&filter);
2100
2101        // With 100 distinct values, selectivity should be ~1%
2102        // 1000 * 0.01 = 10
2103        assert!((1.0..50.0).contains(&cardinality));
2104    }
2105
2106    #[test]
2107    fn test_histogram_min_max() {
2108        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2109        let histogram = EquiDepthHistogram::build(&values, 2);
2110
2111        assert_eq!(histogram.min_value(), Some(5.0));
2112        // Max is the upper bound of the last bucket
2113        assert!(histogram.max_value().is_some());
2114    }
2115
2116    // ==================== SelectivityConfig Tests ====================
2117
2118    #[test]
2119    fn test_selectivity_config_defaults() {
2120        let config = SelectivityConfig::new();
2121        assert!((config.default - 0.1).abs() < f64::EPSILON);
2122        assert!((config.equality - 0.01).abs() < f64::EPSILON);
2123        assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2124        assert!((config.range - 0.33).abs() < f64::EPSILON);
2125        assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2126        assert!((config.membership - 0.1).abs() < f64::EPSILON);
2127        assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2128        assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2129        assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2130    }
2131
2132    #[test]
2133    fn test_custom_selectivity_config() {
2134        let config = SelectivityConfig {
2135            equality: 0.05,
2136            range: 0.25,
2137            ..SelectivityConfig::new()
2138        };
2139        let estimator = CardinalityEstimator::with_selectivity_config(config);
2140        assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2141        assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2142    }
2143
2144    #[test]
2145    fn test_custom_selectivity_affects_estimation() {
2146        // Default: equality = 0.01 → 1000 * 0.01 = 10
2147        let mut default_est = CardinalityEstimator::new();
2148        default_est.add_table_stats("Person", TableStats::new(1000));
2149
2150        let filter = LogicalOperator::Filter(FilterOp {
2151            predicate: LogicalExpression::Binary {
2152                left: Box::new(LogicalExpression::Property {
2153                    variable: "n".to_string(),
2154                    property: "name".to_string(),
2155                }),
2156                op: BinaryOp::Eq,
2157                right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2158            },
2159            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2160                variable: "n".to_string(),
2161                label: Some("Person".to_string()),
2162                input: None,
2163            })),
2164            pushdown_hint: None,
2165        });
2166
2167        let default_card = default_est.estimate(&filter);
2168
2169        // Custom: equality = 0.2 → 1000 * 0.2 = 200
2170        let config = SelectivityConfig {
2171            equality: 0.2,
2172            ..SelectivityConfig::new()
2173        };
2174        let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2175        custom_est.add_table_stats("Person", TableStats::new(1000));
2176
2177        let custom_card = custom_est.estimate(&filter);
2178
2179        assert!(custom_card > default_card);
2180        assert!((custom_card - 200.0).abs() < 1.0);
2181    }
2182
2183    #[test]
2184    fn test_custom_range_selectivity() {
2185        let config = SelectivityConfig {
2186            range: 0.5,
2187            ..SelectivityConfig::new()
2188        };
2189        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2190        estimator.add_table_stats("Person", TableStats::new(1000));
2191
2192        let filter = LogicalOperator::Filter(FilterOp {
2193            predicate: LogicalExpression::Binary {
2194                left: Box::new(LogicalExpression::Property {
2195                    variable: "n".to_string(),
2196                    property: "age".to_string(),
2197                }),
2198                op: BinaryOp::Gt,
2199                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2200            },
2201            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2202                variable: "n".to_string(),
2203                label: Some("Person".to_string()),
2204                input: None,
2205            })),
2206            pushdown_hint: None,
2207        });
2208
2209        let cardinality = estimator.estimate(&filter);
2210        // 1000 * 0.5 = 500
2211        assert!((cardinality - 500.0).abs() < 1.0);
2212    }
2213
2214    #[test]
2215    fn test_custom_distinct_fraction() {
2216        let config = SelectivityConfig {
2217            distinct_fraction: 0.8,
2218            ..SelectivityConfig::new()
2219        };
2220        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2221        estimator.add_table_stats("Person", TableStats::new(1000));
2222
2223        let distinct = LogicalOperator::Distinct(DistinctOp {
2224            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2225                variable: "n".to_string(),
2226                label: Some("Person".to_string()),
2227                input: None,
2228            })),
2229            columns: None,
2230        });
2231
2232        let cardinality = estimator.estimate(&distinct);
2233        // 1000 * 0.8 = 800
2234        assert!((cardinality - 800.0).abs() < 1.0);
2235    }
2236
2237    // ==================== EstimationLog Tests ====================
2238
2239    #[test]
2240    fn test_estimation_log_basic() {
2241        let mut log = EstimationLog::new(10.0);
2242        log.record("NodeScan(Person)", 1000.0, 1200.0);
2243        log.record("Filter(age > 30)", 100.0, 90.0);
2244
2245        assert_eq!(log.entries().len(), 2);
2246        assert!(!log.should_replan()); // 1.2x and 0.9x are within 10x threshold
2247    }
2248
2249    #[test]
2250    fn test_estimation_log_triggers_replan() {
2251        let mut log = EstimationLog::new(10.0);
2252        log.record("NodeScan(Person)", 100.0, 5000.0); // 50x underestimate
2253
2254        assert!(log.should_replan());
2255    }
2256
2257    #[test]
2258    fn test_estimation_log_overestimate_triggers_replan() {
2259        let mut log = EstimationLog::new(5.0);
2260        log.record("Filter", 1000.0, 100.0); // 10x overestimate → ratio = 0.1
2261
2262        assert!(log.should_replan()); // 0.1 < 1/5 = 0.2
2263    }
2264
2265    #[test]
2266    fn test_estimation_entry_error_ratio() {
2267        let entry = EstimationEntry {
2268            operator: "test".into(),
2269            estimated: 100.0,
2270            actual: 200.0,
2271        };
2272        assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2273
2274        let perfect = EstimationEntry {
2275            operator: "test".into(),
2276            estimated: 100.0,
2277            actual: 100.0,
2278        };
2279        assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2280
2281        let zero_est = EstimationEntry {
2282            operator: "test".into(),
2283            estimated: 0.0,
2284            actual: 0.0,
2285        };
2286        assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2287    }
2288
2289    #[test]
2290    fn test_estimation_log_max_error_ratio() {
2291        let mut log = EstimationLog::new(10.0);
2292        log.record("A", 100.0, 300.0); // 3x
2293        log.record("B", 100.0, 50.0); // 2x (normalized: 1/0.5 = 2)
2294        log.record("C", 100.0, 100.0); // 1x
2295
2296        assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2297    }
2298
2299    #[test]
2300    fn test_estimation_log_clear() {
2301        let mut log = EstimationLog::new(10.0);
2302        log.record("A", 100.0, 100.0);
2303        assert_eq!(log.entries().len(), 1);
2304
2305        log.clear();
2306        assert!(log.entries().is_empty());
2307        assert!(!log.should_replan());
2308    }
2309
2310    #[test]
2311    fn test_create_estimation_log() {
2312        let log = CardinalityEstimator::create_estimation_log();
2313        assert!(log.entries().is_empty());
2314        assert!(!log.should_replan());
2315    }
2316}