Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20    UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram.
///
/// A bucket describes a value range together with the number of rows (and
/// distinct values) that fall inside it. In an equi-depth histogram every
/// bucket holds roughly the same number of rows.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Upper edge of the bucket (exclusive, except for the last bucket).
    pub upper_bound: f64,
    /// How many rows fall into this bucket.
    pub frequency: u64,
    /// How many distinct values fall into this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Assembles a bucket from its bounds and counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Distance between the upper and the lower bound.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Tells whether `value` lies in `[lower_bound, upper_bound)`.
    ///
    /// A zero-width bucket therefore contains no value at all.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Fraction (0.0 to 1.0) of this bucket that the range `lower..upper` covers.
    ///
    /// A bound of `None` means the range is unbounded on that side.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the requested range to the bucket.
        let clipped_lo = lower.map_or(self.lower_bound, |lo| lo.max(self.lower_bound));
        let clipped_hi = upper.map_or(self.upper_bound, |hi| hi.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Degenerate bucket: it is either covered completely or not at all.
            let fully_covered =
                clipped_lo <= self.lower_bound && clipped_hi >= self.upper_bound;
            return if fully_covered { 1.0 } else { 0.0 };
        }

        let covered = (clipped_hi - clipped_lo).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```no_run
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower
213                && bucket_upper <= l
214            {
215                continue;
216            }
217            if let Some(u) = upper
218                && bucket_lower >= u
219            {
220                continue;
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
/// Counts the distinct values of an already sorted slice.
///
/// A value is counted as new only when it differs from the most recently
/// counted value by more than `f64::EPSILON`.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    match sorted_values.split_first() {
        None => 0,
        Some((&first, rest)) => {
            // Accumulator: (values counted so far, last value that was counted).
            let (distinct, _) = rest.iter().fold(
                (1u64, first),
                |(seen, last_counted), &current| {
                    if (current - last_counted).abs() > f64::EPSILON {
                        (seen + 1, current)
                    } else {
                        (seen, last_counted)
                    }
                },
            );
            distinct
        }
    }
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
/// Tunable fallback selectivities for cardinality estimation.
///
/// These values are the assumed selectivity of each predicate kind when no
/// histogram or column statistics are available. Tuning them can improve
/// plan quality for workloads with known skew.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Selectivity for unknown predicates (default: 0.1).
    pub default: f64,
    /// Selectivity for equality predicates without stats (default: 0.01).
    pub equality: f64,
    /// Selectivity for inequality predicates (default: 0.99).
    pub inequality: f64,
    /// Selectivity for range predicates without stats (default: 0.33).
    pub range: f64,
    /// Selectivity for string operations: STARTS WITH, ENDS WITH, CONTAINS, LIKE (default: 0.1).
    pub string_ops: f64,
    /// Selectivity for IN membership (default: 0.1).
    pub membership: f64,
    /// Selectivity for IS NULL (default: 0.05).
    pub is_null: f64,
    /// Selectivity for IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Fraction assumed distinct for DISTINCT operations (default: 0.5).
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns a config populated with the standard database defaults.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    /// Standard defaults; the per-field values are listed on the struct fields.
    fn default() -> Self {
        Self {
            // Generic fallbacks
            default: 0.1,
            string_ops: 0.1,
            membership: 0.1,
            // Comparisons
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            // Null checks
            is_null: 0.05,
            is_not_null: 0.95,
            // DISTINCT
            distinct_fraction: 0.5,
        }
    }
}
446
/// A single estimate-vs-actual observation for analysis.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Human-readable label for the operator (e.g., "NodeScan(Person)").
    pub operator: String,
    /// The cardinality estimate produced by the optimizer.
    pub estimated: f64,
    /// The actual row count observed at execution time.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns the estimation error ratio (actual / estimated).
    ///
    /// Values near 1.0 indicate accurate estimates.
    /// Values > 1.0 indicate underestimation.
    /// Values < 1.0 indicate overestimation.
    ///
    /// When the estimate is (numerically) zero the ratio is `1.0` if the
    /// actual count is zero as well, and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        if self.estimated.abs() < f64::EPSILON {
            if self.actual.abs() < f64::EPSILON {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            self.actual / self.estimated
        }
    }
}

/// Collects estimate vs actual cardinality data for query plan analysis.
///
/// After executing a query, call [`record()`](Self::record) for each
/// operator with its estimated and actual cardinalities. Then use
/// [`should_replan()`](Self::should_replan) to decide whether the plan
/// should be re-optimized.
///
/// `EstimationLog::default()` uses the documented default threshold of 10.0.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries.
    entries: Vec<EstimationEntry>,
    /// Error ratio threshold that triggers re-planning (default: 10.0).
    ///
    /// If any operator's error ratio exceeds this, `should_replan()` returns true.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// Creates an empty log with the default re-planning threshold (10.0).
    ///
    /// Implemented by hand: a derived `Default` would set the threshold to
    /// `0.0`, which makes `should_replan()` fire for every finite positive
    /// error ratio, including perfectly accurate estimates.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REPLAN_THRESHOLD)
    }
}

impl EstimationLog {
    /// Re-planning threshold used by [`Default`].
    const DEFAULT_REPLAN_THRESHOLD: f64 = 10.0;

    /// Creates a new estimation log with the given re-planning threshold.
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Records an estimate-vs-actual observation.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// Returns all recorded entries.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns whether any operator's estimation error exceeds the threshold,
    /// indicating the plan should be re-optimized.
    ///
    /// Both directions count: a ratio above `threshold` (underestimation) or
    /// below `1 / threshold` (overestimation) triggers re-planning.
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Returns the maximum error ratio across all entries.
    ///
    /// Ratios below 1.0 are inverted first, so the result is at least 1.0
    /// (which is also the value for an empty log).
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                // Normalize so both over- and under-estimation are > 1.0
                if r < 1.0 { 1.0 / r } else { r }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Clears all entries.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
547
/// Cardinality estimator.
///
/// Estimates how many rows each operator of a logical plan produces, using
/// per-label table statistics where available and configurable fallbacks
/// otherwise.
pub struct CardinalityEstimator {
    /// Statistics for each label/table, keyed by label name.
    table_stats: HashMap<String, TableStats>,
    /// Default row count for unknown tables.
    default_row_count: u64,
    /// Default selectivity for unknown predicates.
    ///
    /// Initialized from `selectivity_config.default` by the constructors.
    default_selectivity: f64,
    /// Average edge fanout (outgoing edges per node).
    avg_fanout: f64,
    /// Configurable selectivity defaults.
    selectivity_config: SelectivityConfig,
}
561
562impl CardinalityEstimator {
563    /// Creates a new cardinality estimator with default settings.
564    #[must_use]
565    pub fn new() -> Self {
566        let config = SelectivityConfig::new();
567        Self {
568            table_stats: HashMap::new(),
569            default_row_count: 1000,
570            default_selectivity: config.default,
571            avg_fanout: 10.0,
572            selectivity_config: config,
573        }
574    }
575
576    /// Creates a new cardinality estimator with custom selectivity configuration.
577    #[must_use]
578    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
579        Self {
580            table_stats: HashMap::new(),
581            default_row_count: 1000,
582            default_selectivity: config.default,
583            avg_fanout: 10.0,
584            selectivity_config: config,
585        }
586    }
587
588    /// Returns the current selectivity configuration.
589    #[must_use]
590    pub fn selectivity_config(&self) -> &SelectivityConfig {
591        &self.selectivity_config
592    }
593
594    /// Creates an estimation log with the default re-planning threshold (10x).
595    #[must_use]
596    pub fn create_estimation_log() -> EstimationLog {
597        EstimationLog::new(10.0)
598    }
599
600    /// Creates a cardinality estimator pre-populated from store statistics.
601    ///
602    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
603    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
604    /// missing statistics.
605    #[must_use]
606    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
607        let mut estimator = Self::new();
608
609        // Use total node count as default for unlabeled scans
610        if stats.total_nodes > 0 {
611            estimator.default_row_count = stats.total_nodes;
612        }
613
614        // Convert label statistics to optimizer table stats
615        for (label, label_stats) in &stats.labels {
616            let mut table_stats = TableStats::new(label_stats.node_count);
617
618            // Map property statistics (distinct count for selectivity estimation)
619            for (prop, col_stats) in &label_stats.properties {
620                let optimizer_col =
621                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
622                table_stats = table_stats.with_column(prop, optimizer_col);
623            }
624
625            estimator.add_table_stats(label, table_stats);
626        }
627
628        // Compute average fanout from edge type statistics
629        if !stats.edge_types.is_empty() {
630            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
631            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
632        } else if stats.total_nodes > 0 {
633            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
634        }
635
636        // Clamp fanout to a reasonable minimum
637        if estimator.avg_fanout < 1.0 {
638            estimator.avg_fanout = 1.0;
639        }
640
641        estimator
642    }
643
644    /// Adds statistics for a table/label.
645    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
646        self.table_stats.insert(name.to_string(), stats);
647    }
648
    /// Sets the average edge fanout (outgoing edges per node).
    ///
    /// The value is stored as given; no clamping or validation happens here.
    pub fn set_avg_fanout(&mut self, fanout: f64) {
        self.avg_fanout = fanout;
    }
653
654    /// Estimates the cardinality of a logical operator.
655    #[must_use]
656    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
657        match op {
658            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
659            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
660            LogicalOperator::Project(project) => self.estimate_project(project),
661            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
662            LogicalOperator::Join(join) => self.estimate_join(join),
663            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
664            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
665            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
666            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
667            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
668            LogicalOperator::Return(ret) => self.estimate(&ret.input),
669            LogicalOperator::Empty => 0.0,
670            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
671            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
672            LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
673            _ => self.default_row_count as f64,
674        }
675    }
676
677    /// Estimates node scan cardinality.
678    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
679        if let Some(label) = &scan.label
680            && let Some(stats) = self.table_stats.get(label)
681        {
682            return stats.row_count as f64;
683        }
684        // No label filter - scan all nodes
685        self.default_row_count as f64
686    }
687
688    /// Estimates filter cardinality.
689    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
690        let input_cardinality = self.estimate(&filter.input);
691        let selectivity = self.estimate_selectivity(&filter.predicate);
692        (input_cardinality * selectivity).max(1.0)
693    }
694
    /// Estimates projection cardinality.
    ///
    /// A projection is estimated to produce exactly as many rows as its input.
    fn estimate_project(&self, project: &ProjectOp) -> f64 {
        self.estimate(&project.input)
    }
699
700    /// Estimates expand cardinality.
701    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
702        let input_cardinality = self.estimate(&expand.input);
703
704        // Apply fanout based on edge type
705        let fanout = if !expand.edge_types.is_empty() {
706            // Specific edge type(s) typically have lower fanout
707            self.avg_fanout * 0.5
708        } else {
709            self.avg_fanout
710        };
711
712        // Handle variable-length paths
713        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
714            let min = expand.min_hops as f64;
715            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
716            // Geometric series approximation
717            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
718        } else {
719            fanout
720        };
721
722        (input_cardinality * path_multiplier).max(1.0)
723    }
724
725    /// Estimates join cardinality.
726    fn estimate_join(&self, join: &JoinOp) -> f64 {
727        let left_card = self.estimate(&join.left);
728        let right_card = self.estimate(&join.right);
729
730        match join.join_type {
731            JoinType::Cross => left_card * right_card,
732            JoinType::Inner => {
733                // Assume join selectivity based on conditions
734                let selectivity = if join.conditions.is_empty() {
735                    1.0 // Cross join
736                } else {
737                    // Estimate based on number of conditions
738                    0.1_f64.powi(join.conditions.len() as i32)
739                };
740                (left_card * right_card * selectivity).max(1.0)
741            }
742            JoinType::Left => {
743                // Left join returns at least all left rows
744                let inner_card = self.estimate_join(&JoinOp {
745                    left: join.left.clone(),
746                    right: join.right.clone(),
747                    join_type: JoinType::Inner,
748                    conditions: join.conditions.clone(),
749                });
750                inner_card.max(left_card)
751            }
752            JoinType::Right => {
753                // Right join returns at least all right rows
754                let inner_card = self.estimate_join(&JoinOp {
755                    left: join.left.clone(),
756                    right: join.right.clone(),
757                    join_type: JoinType::Inner,
758                    conditions: join.conditions.clone(),
759                });
760                inner_card.max(right_card)
761            }
762            JoinType::Full => {
763                // Full join returns at least max(left, right)
764                let inner_card = self.estimate_join(&JoinOp {
765                    left: join.left.clone(),
766                    right: join.right.clone(),
767                    join_type: JoinType::Inner,
768                    conditions: join.conditions.clone(),
769                });
770                inner_card.max(left_card.max(right_card))
771            }
772            JoinType::Semi => {
773                // Semi join returns at most left cardinality
774                (left_card * self.default_selectivity).max(1.0)
775            }
776            JoinType::Anti => {
777                // Anti join returns at most left cardinality
778                (left_card * (1.0 - self.default_selectivity)).max(1.0)
779            }
780        }
781    }
782
783    /// Estimates aggregation cardinality.
784    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
785        let input_cardinality = self.estimate(&agg.input);
786
787        if agg.group_by.is_empty() {
788            // Global aggregation - single row
789            1.0
790        } else {
791            // Group by - estimate distinct groups
792            // Assume each group key reduces cardinality by 10
793            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
794            (input_cardinality / group_reduction).max(1.0)
795        }
796    }
797
    /// Estimates sort cardinality.
    ///
    /// A sort is estimated to produce exactly as many rows as its input.
    fn estimate_sort(&self, sort: &SortOp) -> f64 {
        self.estimate(&sort.input)
    }
802
803    /// Estimates distinct cardinality.
804    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
805        let input_cardinality = self.estimate(&distinct.input);
806        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
807    }
808
809    /// Estimates limit cardinality.
810    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
811        let input_cardinality = self.estimate(&limit.input);
812        limit.count.estimate().min(input_cardinality)
813    }
814
815    /// Estimates skip cardinality.
816    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
817        let input_cardinality = self.estimate(&skip.input);
818        (input_cardinality - skip.count.estimate()).max(0.0)
819    }
820
821    /// Estimates vector scan cardinality.
822    ///
823    /// Vector scan returns at most k results (the k nearest neighbors).
824    /// With similarity/distance filters, it may return fewer.
825    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
826        let base_k = scan.k as f64;
827
828        // Apply filter selectivity if thresholds are specified
829        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
830            // Assume 70% of results pass threshold filters
831            0.7
832        } else {
833            1.0
834        };
835
836        (base_k * selectivity).max(1.0)
837    }
838
839    /// Estimates vector join cardinality.
840    ///
841    /// Vector join produces up to k results per input row.
842    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
843        let input_cardinality = self.estimate(&join.input);
844        let k = join.k as f64;
845
846        // Apply filter selectivity if thresholds are specified
847        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
848            0.7
849        } else {
850            1.0
851        };
852
853        (input_cardinality * k * selectivity).max(1.0)
854    }
855
856    /// Estimates multi-way join cardinality using the AGM bound heuristic.
857    ///
858    /// For a cyclic join of N relations, the AGM (Atserias-Grohe-Marx) bound
859    /// gives min(cardinality)^(N/2) as a worst-case output size estimate.
860    fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
861        if mwj.inputs.is_empty() {
862            return 0.0;
863        }
864        let cardinalities: Vec<f64> = mwj
865            .inputs
866            .iter()
867            .map(|input| self.estimate(input))
868            .collect();
869        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
870        let n = cardinalities.len() as f64;
871        // AGM bound: min(cardinality)^(n/2)
872        (min_card.powf(n / 2.0)).max(1.0)
873    }
874
875    /// Estimates the selectivity of a predicate (0.0 to 1.0).
876    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
877        match expr {
878            LogicalExpression::Binary { left, op, right } => {
879                self.estimate_binary_selectivity(left, *op, right)
880            }
881            LogicalExpression::Unary { op, operand } => {
882                self.estimate_unary_selectivity(*op, operand)
883            }
884            LogicalExpression::Literal(value) => {
885                // Boolean literal
886                if let grafeo_common::types::Value::Bool(b) = value {
887                    if *b { 1.0 } else { 0.0 }
888                } else {
889                    self.default_selectivity
890                }
891            }
892            _ => self.default_selectivity,
893        }
894    }
895
896    /// Estimates binary expression selectivity.
897    fn estimate_binary_selectivity(
898        &self,
899        left: &LogicalExpression,
900        op: BinaryOp,
901        right: &LogicalExpression,
902    ) -> f64 {
903        match op {
904            // Equality - try histogram-based estimation
905            BinaryOp::Eq => {
906                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
907                    return selectivity;
908                }
909                self.selectivity_config.equality
910            }
911            // Inequality is very unselective
912            BinaryOp::Ne => self.selectivity_config.inequality,
913            // Range predicates - use histogram if available
914            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
915                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
916                    return selectivity;
917                }
918                self.selectivity_config.range
919            }
920            // Logical operators - recursively estimate sub-expressions
921            BinaryOp::And => {
922                let left_sel = self.estimate_selectivity(left);
923                let right_sel = self.estimate_selectivity(right);
924                // AND reduces selectivity (multiply assuming independence)
925                left_sel * right_sel
926            }
927            BinaryOp::Or => {
928                let left_sel = self.estimate_selectivity(left);
929                let right_sel = self.estimate_selectivity(right);
930                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
931                // Assuming independence: P(A ∩ B) = P(A) * P(B)
932                (left_sel + right_sel - left_sel * right_sel).min(1.0)
933            }
934            // String operations
935            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
936                self.selectivity_config.string_ops
937            }
938            // Collection membership
939            BinaryOp::In => self.selectivity_config.membership,
940            // Other operations
941            _ => self.default_selectivity,
942        }
943    }
944
945    /// Tries to estimate equality selectivity using histograms.
946    fn try_equality_selectivity(
947        &self,
948        left: &LogicalExpression,
949        right: &LogicalExpression,
950    ) -> Option<f64> {
951        // Extract property access and literal value
952        let (label, column, value) = self.extract_column_and_value(left, right)?;
953
954        // Get column stats with histogram
955        let stats = self.get_column_stats(&label, &column)?;
956
957        // Try histogram-based estimation
958        if let Some(ref histogram) = stats.histogram {
959            return Some(histogram.equality_selectivity(value));
960        }
961
962        // Fall back to distinct count estimation
963        if stats.distinct_count > 0 {
964            return Some(1.0 / stats.distinct_count as f64);
965        }
966
967        None
968    }
969
970    /// Tries to estimate range selectivity using histograms.
971    fn try_range_selectivity(
972        &self,
973        left: &LogicalExpression,
974        op: BinaryOp,
975        right: &LogicalExpression,
976    ) -> Option<f64> {
977        // Extract property access and literal value
978        let (label, column, value) = self.extract_column_and_value(left, right)?;
979
980        // Get column stats
981        let stats = self.get_column_stats(&label, &column)?;
982
983        // Determine the range based on operator
984        let (lower, upper) = match op {
985            BinaryOp::Lt => (None, Some(value)),
986            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
987            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
988            BinaryOp::Ge => (Some(value), None),
989            _ => return None,
990        };
991
992        // Try histogram-based estimation first
993        if let Some(ref histogram) = stats.histogram {
994            return Some(histogram.range_selectivity(lower, upper));
995        }
996
997        // Fall back to min/max range estimation
998        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
999            let range = max - min;
1000            if range <= 0.0 {
1001                return Some(1.0);
1002            }
1003
1004            let effective_lower = lower.unwrap_or(min).max(min);
1005            let effective_upper = upper.unwrap_or(max).min(max);
1006            let overlap = (effective_upper - effective_lower).max(0.0);
1007            return Some((overlap / range).clamp(0.0, 1.0));
1008        }
1009
1010        None
1011    }
1012
1013    /// Extracts column information and literal value from a comparison.
1014    ///
1015    /// Returns (label, column_name, numeric_value) if the expression is
1016    /// a comparison between a property access and a numeric literal.
1017    fn extract_column_and_value(
1018        &self,
1019        left: &LogicalExpression,
1020        right: &LogicalExpression,
1021    ) -> Option<(String, String, f64)> {
1022        // Try left as property, right as literal
1023        if let Some(result) = self.try_extract_property_literal(left, right) {
1024            return Some(result);
1025        }
1026
1027        // Try right as property, left as literal
1028        self.try_extract_property_literal(right, left)
1029    }
1030
1031    /// Tries to extract property and literal from a specific ordering.
1032    fn try_extract_property_literal(
1033        &self,
1034        property_expr: &LogicalExpression,
1035        literal_expr: &LogicalExpression,
1036    ) -> Option<(String, String, f64)> {
1037        // Extract property access
1038        let (variable, property) = match property_expr {
1039            LogicalExpression::Property { variable, property } => {
1040                (variable.clone(), property.clone())
1041            }
1042            _ => return None,
1043        };
1044
1045        // Extract numeric literal
1046        let value = match literal_expr {
1047            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1048            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1049            _ => return None,
1050        };
1051
1052        // Try to find a label for this variable from table stats
1053        // Use the variable name as a heuristic label lookup
1054        // In practice, the optimizer would track which labels variables are bound to
1055        for label in self.table_stats.keys() {
1056            if let Some(stats) = self.table_stats.get(label)
1057                && stats.columns.contains_key(&property)
1058            {
1059                return Some((label.clone(), property, value));
1060            }
1061        }
1062
1063        // If no stats found but we have the property, return with variable as label
1064        Some((variable, property, value))
1065    }
1066
1067    /// Estimates unary expression selectivity.
1068    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1069        match op {
1070            UnaryOp::Not => 1.0 - self.default_selectivity,
1071            UnaryOp::IsNull => self.selectivity_config.is_null,
1072            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1073            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1074        }
1075    }
1076
1077    /// Gets statistics for a column.
1078    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1079        self.table_stats.get(label)?.columns.get(column)
1080    }
1081}
1082
1083impl Default for CardinalityEstimator {
1084    fn default() -> Self {
1085        Self::new()
1086    }
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091    use super::*;
1092    use crate::query::plan::{
1093        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1094        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1095    };
1096    use grafeo_common::types::Value;
1097
1098    #[test]
1099    fn test_node_scan_with_stats() {
1100        let mut estimator = CardinalityEstimator::new();
1101        estimator.add_table_stats("Person", TableStats::new(5000));
1102
1103        let scan = LogicalOperator::NodeScan(NodeScanOp {
1104            variable: "n".to_string(),
1105            label: Some("Person".to_string()),
1106            input: None,
1107        });
1108
1109        let cardinality = estimator.estimate(&scan);
1110        assert!((cardinality - 5000.0).abs() < 0.001);
1111    }
1112
1113    #[test]
1114    fn test_filter_reduces_cardinality() {
1115        let mut estimator = CardinalityEstimator::new();
1116        estimator.add_table_stats("Person", TableStats::new(1000));
1117
1118        let filter = LogicalOperator::Filter(FilterOp {
1119            predicate: LogicalExpression::Binary {
1120                left: Box::new(LogicalExpression::Property {
1121                    variable: "n".to_string(),
1122                    property: "age".to_string(),
1123                }),
1124                op: BinaryOp::Eq,
1125                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1126            },
1127            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1128                variable: "n".to_string(),
1129                label: Some("Person".to_string()),
1130                input: None,
1131            })),
1132            pushdown_hint: None,
1133        });
1134
1135        let cardinality = estimator.estimate(&filter);
1136        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1137        assert!(cardinality < 1000.0);
1138        assert!(cardinality >= 1.0);
1139    }
1140
1141    #[test]
1142    fn test_join_cardinality() {
1143        let mut estimator = CardinalityEstimator::new();
1144        estimator.add_table_stats("Person", TableStats::new(1000));
1145        estimator.add_table_stats("Company", TableStats::new(100));
1146
1147        let join = LogicalOperator::Join(JoinOp {
1148            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1149                variable: "p".to_string(),
1150                label: Some("Person".to_string()),
1151                input: None,
1152            })),
1153            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1154                variable: "c".to_string(),
1155                label: Some("Company".to_string()),
1156                input: None,
1157            })),
1158            join_type: JoinType::Inner,
1159            conditions: vec![JoinCondition {
1160                left: LogicalExpression::Property {
1161                    variable: "p".to_string(),
1162                    property: "company_id".to_string(),
1163                },
1164                right: LogicalExpression::Property {
1165                    variable: "c".to_string(),
1166                    property: "id".to_string(),
1167                },
1168            }],
1169        });
1170
1171        let cardinality = estimator.estimate(&join);
1172        // Should be less than cross product
1173        assert!(cardinality < 1000.0 * 100.0);
1174    }
1175
1176    #[test]
1177    fn test_limit_caps_cardinality() {
1178        let mut estimator = CardinalityEstimator::new();
1179        estimator.add_table_stats("Person", TableStats::new(1000));
1180
1181        let limit = LogicalOperator::Limit(LimitOp {
1182            count: 10.into(),
1183            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1184                variable: "n".to_string(),
1185                label: Some("Person".to_string()),
1186                input: None,
1187            })),
1188        });
1189
1190        let cardinality = estimator.estimate(&limit);
1191        assert!((cardinality - 10.0).abs() < 0.001);
1192    }
1193
1194    #[test]
1195    fn test_aggregate_reduces_cardinality() {
1196        let mut estimator = CardinalityEstimator::new();
1197        estimator.add_table_stats("Person", TableStats::new(1000));
1198
1199        // Global aggregation
1200        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1201            group_by: vec![],
1202            aggregates: vec![],
1203            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1204                variable: "n".to_string(),
1205                label: Some("Person".to_string()),
1206                input: None,
1207            })),
1208            having: None,
1209        });
1210
1211        let cardinality = estimator.estimate(&global_agg);
1212        assert!((cardinality - 1.0).abs() < 0.001);
1213
1214        // Group by aggregation
1215        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1216            group_by: vec![LogicalExpression::Property {
1217                variable: "n".to_string(),
1218                property: "city".to_string(),
1219            }],
1220            aggregates: vec![],
1221            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1222                variable: "n".to_string(),
1223                label: Some("Person".to_string()),
1224                input: None,
1225            })),
1226            having: None,
1227        });
1228
1229        let cardinality = estimator.estimate(&group_agg);
1230        // Should be less than input
1231        assert!(cardinality < 1000.0);
1232    }
1233
1234    #[test]
1235    fn test_node_scan_without_stats() {
1236        let estimator = CardinalityEstimator::new();
1237
1238        let scan = LogicalOperator::NodeScan(NodeScanOp {
1239            variable: "n".to_string(),
1240            label: Some("Unknown".to_string()),
1241            input: None,
1242        });
1243
1244        let cardinality = estimator.estimate(&scan);
1245        // Should return default (1000)
1246        assert!((cardinality - 1000.0).abs() < 0.001);
1247    }
1248
1249    #[test]
1250    fn test_node_scan_no_label() {
1251        let estimator = CardinalityEstimator::new();
1252
1253        let scan = LogicalOperator::NodeScan(NodeScanOp {
1254            variable: "n".to_string(),
1255            label: None,
1256            input: None,
1257        });
1258
1259        let cardinality = estimator.estimate(&scan);
1260        // Should scan all nodes (default)
1261        assert!((cardinality - 1000.0).abs() < 0.001);
1262    }
1263
1264    #[test]
1265    fn test_filter_inequality_selectivity() {
1266        let mut estimator = CardinalityEstimator::new();
1267        estimator.add_table_stats("Person", TableStats::new(1000));
1268
1269        let filter = LogicalOperator::Filter(FilterOp {
1270            predicate: LogicalExpression::Binary {
1271                left: Box::new(LogicalExpression::Property {
1272                    variable: "n".to_string(),
1273                    property: "age".to_string(),
1274                }),
1275                op: BinaryOp::Ne,
1276                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1277            },
1278            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1279                variable: "n".to_string(),
1280                label: Some("Person".to_string()),
1281                input: None,
1282            })),
1283            pushdown_hint: None,
1284        });
1285
1286        let cardinality = estimator.estimate(&filter);
1287        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1288        assert!(cardinality > 900.0);
1289    }
1290
1291    #[test]
1292    fn test_filter_range_selectivity() {
1293        let mut estimator = CardinalityEstimator::new();
1294        estimator.add_table_stats("Person", TableStats::new(1000));
1295
1296        let filter = LogicalOperator::Filter(FilterOp {
1297            predicate: LogicalExpression::Binary {
1298                left: Box::new(LogicalExpression::Property {
1299                    variable: "n".to_string(),
1300                    property: "age".to_string(),
1301                }),
1302                op: BinaryOp::Gt,
1303                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1304            },
1305            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1306                variable: "n".to_string(),
1307                label: Some("Person".to_string()),
1308                input: None,
1309            })),
1310            pushdown_hint: None,
1311        });
1312
1313        let cardinality = estimator.estimate(&filter);
1314        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1315        assert!(cardinality < 500.0);
1316        assert!(cardinality > 100.0);
1317    }
1318
1319    #[test]
1320    fn test_filter_and_selectivity() {
1321        let mut estimator = CardinalityEstimator::new();
1322        estimator.add_table_stats("Person", TableStats::new(1000));
1323
1324        // Test AND with two equality predicates
1325        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1326        let filter = LogicalOperator::Filter(FilterOp {
1327            predicate: LogicalExpression::Binary {
1328                left: Box::new(LogicalExpression::Binary {
1329                    left: Box::new(LogicalExpression::Property {
1330                        variable: "n".to_string(),
1331                        property: "city".to_string(),
1332                    }),
1333                    op: BinaryOp::Eq,
1334                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1335                }),
1336                op: BinaryOp::And,
1337                right: Box::new(LogicalExpression::Binary {
1338                    left: Box::new(LogicalExpression::Property {
1339                        variable: "n".to_string(),
1340                        property: "age".to_string(),
1341                    }),
1342                    op: BinaryOp::Eq,
1343                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1344                }),
1345            },
1346            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1347                variable: "n".to_string(),
1348                label: Some("Person".to_string()),
1349                input: None,
1350            })),
1351            pushdown_hint: None,
1352        });
1353
1354        let cardinality = estimator.estimate(&filter);
1355        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1356        // 1000 * 0.0001 = 0.1, min is 1.0
1357        assert!(cardinality < 100.0);
1358        assert!(cardinality >= 1.0);
1359    }
1360
1361    #[test]
1362    fn test_filter_or_selectivity() {
1363        let mut estimator = CardinalityEstimator::new();
1364        estimator.add_table_stats("Person", TableStats::new(1000));
1365
1366        // Test OR with two equality predicates
1367        // Each equality has selectivity 0.01
1368        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1369        let filter = LogicalOperator::Filter(FilterOp {
1370            predicate: LogicalExpression::Binary {
1371                left: Box::new(LogicalExpression::Binary {
1372                    left: Box::new(LogicalExpression::Property {
1373                        variable: "n".to_string(),
1374                        property: "city".to_string(),
1375                    }),
1376                    op: BinaryOp::Eq,
1377                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1378                }),
1379                op: BinaryOp::Or,
1380                right: Box::new(LogicalExpression::Binary {
1381                    left: Box::new(LogicalExpression::Property {
1382                        variable: "n".to_string(),
1383                        property: "city".to_string(),
1384                    }),
1385                    op: BinaryOp::Eq,
1386                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1387                }),
1388            },
1389            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1390                variable: "n".to_string(),
1391                label: Some("Person".to_string()),
1392                input: None,
1393            })),
1394            pushdown_hint: None,
1395        });
1396
1397        let cardinality = estimator.estimate(&filter);
1398        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1399        assert!(cardinality < 100.0);
1400        assert!(cardinality >= 1.0);
1401    }
1402
1403    #[test]
1404    fn test_filter_literal_true() {
1405        let mut estimator = CardinalityEstimator::new();
1406        estimator.add_table_stats("Person", TableStats::new(1000));
1407
1408        let filter = LogicalOperator::Filter(FilterOp {
1409            predicate: LogicalExpression::Literal(Value::Bool(true)),
1410            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1411                variable: "n".to_string(),
1412                label: Some("Person".to_string()),
1413                input: None,
1414            })),
1415            pushdown_hint: None,
1416        });
1417
1418        let cardinality = estimator.estimate(&filter);
1419        // Literal true has selectivity 1.0
1420        assert!((cardinality - 1000.0).abs() < 0.001);
1421    }
1422
1423    #[test]
1424    fn test_filter_literal_false() {
1425        let mut estimator = CardinalityEstimator::new();
1426        estimator.add_table_stats("Person", TableStats::new(1000));
1427
1428        let filter = LogicalOperator::Filter(FilterOp {
1429            predicate: LogicalExpression::Literal(Value::Bool(false)),
1430            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1431                variable: "n".to_string(),
1432                label: Some("Person".to_string()),
1433                input: None,
1434            })),
1435            pushdown_hint: None,
1436        });
1437
1438        let cardinality = estimator.estimate(&filter);
1439        // Literal false has selectivity 0.0, but min is 1.0
1440        assert!((cardinality - 1.0).abs() < 0.001);
1441    }
1442
1443    #[test]
1444    fn test_unary_not_selectivity() {
1445        let mut estimator = CardinalityEstimator::new();
1446        estimator.add_table_stats("Person", TableStats::new(1000));
1447
1448        let filter = LogicalOperator::Filter(FilterOp {
1449            predicate: LogicalExpression::Unary {
1450                op: UnaryOp::Not,
1451                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1452            },
1453            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1454                variable: "n".to_string(),
1455                label: Some("Person".to_string()),
1456                input: None,
1457            })),
1458            pushdown_hint: None,
1459        });
1460
1461        let cardinality = estimator.estimate(&filter);
1462        // NOT inverts selectivity
1463        assert!(cardinality < 1000.0);
1464    }
1465
1466    #[test]
1467    fn test_unary_is_null_selectivity() {
1468        let mut estimator = CardinalityEstimator::new();
1469        estimator.add_table_stats("Person", TableStats::new(1000));
1470
1471        let filter = LogicalOperator::Filter(FilterOp {
1472            predicate: LogicalExpression::Unary {
1473                op: UnaryOp::IsNull,
1474                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1475            },
1476            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1477                variable: "n".to_string(),
1478                label: Some("Person".to_string()),
1479                input: None,
1480            })),
1481            pushdown_hint: None,
1482        });
1483
1484        let cardinality = estimator.estimate(&filter);
1485        // IS NULL has selectivity 0.05
1486        assert!(cardinality < 100.0);
1487    }
1488
1489    #[test]
1490    fn test_expand_cardinality() {
1491        let mut estimator = CardinalityEstimator::new();
1492        estimator.add_table_stats("Person", TableStats::new(100));
1493
1494        let expand = LogicalOperator::Expand(ExpandOp {
1495            from_variable: "a".to_string(),
1496            to_variable: "b".to_string(),
1497            edge_variable: None,
1498            direction: ExpandDirection::Outgoing,
1499            edge_types: vec![],
1500            min_hops: 1,
1501            max_hops: Some(1),
1502            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1503                variable: "a".to_string(),
1504                label: Some("Person".to_string()),
1505                input: None,
1506            })),
1507            path_alias: None,
1508            path_mode: PathMode::Walk,
1509        });
1510
1511        let cardinality = estimator.estimate(&expand);
1512        // Expand multiplies by fanout (10)
1513        assert!(cardinality > 100.0);
1514    }
1515
1516    #[test]
1517    fn test_expand_with_edge_type_filter() {
1518        let mut estimator = CardinalityEstimator::new();
1519        estimator.add_table_stats("Person", TableStats::new(100));
1520
1521        let expand = LogicalOperator::Expand(ExpandOp {
1522            from_variable: "a".to_string(),
1523            to_variable: "b".to_string(),
1524            edge_variable: None,
1525            direction: ExpandDirection::Outgoing,
1526            edge_types: vec!["KNOWS".to_string()],
1527            min_hops: 1,
1528            max_hops: Some(1),
1529            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1530                variable: "a".to_string(),
1531                label: Some("Person".to_string()),
1532                input: None,
1533            })),
1534            path_alias: None,
1535            path_mode: PathMode::Walk,
1536        });
1537
1538        let cardinality = estimator.estimate(&expand);
1539        // With edge type, fanout is reduced by half
1540        assert!(cardinality > 100.0);
1541    }
1542
1543    #[test]
1544    fn test_expand_variable_length() {
1545        let mut estimator = CardinalityEstimator::new();
1546        estimator.add_table_stats("Person", TableStats::new(100));
1547
1548        let expand = LogicalOperator::Expand(ExpandOp {
1549            from_variable: "a".to_string(),
1550            to_variable: "b".to_string(),
1551            edge_variable: None,
1552            direction: ExpandDirection::Outgoing,
1553            edge_types: vec![],
1554            min_hops: 1,
1555            max_hops: Some(3),
1556            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557                variable: "a".to_string(),
1558                label: Some("Person".to_string()),
1559                input: None,
1560            })),
1561            path_alias: None,
1562            path_mode: PathMode::Walk,
1563        });
1564
1565        let cardinality = estimator.estimate(&expand);
1566        // Variable length path has much higher cardinality
1567        assert!(cardinality > 500.0);
1568    }
1569
1570    #[test]
1571    fn test_join_cross_product() {
1572        let mut estimator = CardinalityEstimator::new();
1573        estimator.add_table_stats("Person", TableStats::new(100));
1574        estimator.add_table_stats("Company", TableStats::new(50));
1575
1576        let join = LogicalOperator::Join(JoinOp {
1577            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1578                variable: "p".to_string(),
1579                label: Some("Person".to_string()),
1580                input: None,
1581            })),
1582            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583                variable: "c".to_string(),
1584                label: Some("Company".to_string()),
1585                input: None,
1586            })),
1587            join_type: JoinType::Cross,
1588            conditions: vec![],
1589        });
1590
1591        let cardinality = estimator.estimate(&join);
1592        // Cross join = 100 * 50 = 5000
1593        assert!((cardinality - 5000.0).abs() < 0.001);
1594    }
1595
1596    #[test]
1597    fn test_join_left_outer() {
1598        let mut estimator = CardinalityEstimator::new();
1599        estimator.add_table_stats("Person", TableStats::new(1000));
1600        estimator.add_table_stats("Company", TableStats::new(10));
1601
1602        let join = LogicalOperator::Join(JoinOp {
1603            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1604                variable: "p".to_string(),
1605                label: Some("Person".to_string()),
1606                input: None,
1607            })),
1608            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1609                variable: "c".to_string(),
1610                label: Some("Company".to_string()),
1611                input: None,
1612            })),
1613            join_type: JoinType::Left,
1614            conditions: vec![JoinCondition {
1615                left: LogicalExpression::Variable("p".to_string()),
1616                right: LogicalExpression::Variable("c".to_string()),
1617            }],
1618        });
1619
1620        let cardinality = estimator.estimate(&join);
1621        // Left join returns at least all left rows
1622        assert!(cardinality >= 1000.0);
1623    }
1624
1625    #[test]
1626    fn test_join_semi() {
1627        let mut estimator = CardinalityEstimator::new();
1628        estimator.add_table_stats("Person", TableStats::new(1000));
1629        estimator.add_table_stats("Company", TableStats::new(100));
1630
1631        let join = LogicalOperator::Join(JoinOp {
1632            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1633                variable: "p".to_string(),
1634                label: Some("Person".to_string()),
1635                input: None,
1636            })),
1637            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1638                variable: "c".to_string(),
1639                label: Some("Company".to_string()),
1640                input: None,
1641            })),
1642            join_type: JoinType::Semi,
1643            conditions: vec![],
1644        });
1645
1646        let cardinality = estimator.estimate(&join);
1647        // Semi join returns at most left cardinality
1648        assert!(cardinality <= 1000.0);
1649    }
1650
1651    #[test]
1652    fn test_join_anti() {
1653        let mut estimator = CardinalityEstimator::new();
1654        estimator.add_table_stats("Person", TableStats::new(1000));
1655        estimator.add_table_stats("Company", TableStats::new(100));
1656
1657        let join = LogicalOperator::Join(JoinOp {
1658            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1659                variable: "p".to_string(),
1660                label: Some("Person".to_string()),
1661                input: None,
1662            })),
1663            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664                variable: "c".to_string(),
1665                label: Some("Company".to_string()),
1666                input: None,
1667            })),
1668            join_type: JoinType::Anti,
1669            conditions: vec![],
1670        });
1671
1672        let cardinality = estimator.estimate(&join);
1673        // Anti join returns at most left cardinality
1674        assert!(cardinality <= 1000.0);
1675        assert!(cardinality >= 1.0);
1676    }
1677
1678    #[test]
1679    fn test_project_preserves_cardinality() {
1680        let mut estimator = CardinalityEstimator::new();
1681        estimator.add_table_stats("Person", TableStats::new(1000));
1682
1683        let project = LogicalOperator::Project(ProjectOp {
1684            projections: vec![Projection {
1685                expression: LogicalExpression::Variable("n".to_string()),
1686                alias: None,
1687            }],
1688            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1689                variable: "n".to_string(),
1690                label: Some("Person".to_string()),
1691                input: None,
1692            })),
1693            pass_through_input: false,
1694        });
1695
1696        let cardinality = estimator.estimate(&project);
1697        assert!((cardinality - 1000.0).abs() < 0.001);
1698    }
1699
1700    #[test]
1701    fn test_sort_preserves_cardinality() {
1702        let mut estimator = CardinalityEstimator::new();
1703        estimator.add_table_stats("Person", TableStats::new(1000));
1704
1705        let sort = LogicalOperator::Sort(SortOp {
1706            keys: vec![SortKey {
1707                expression: LogicalExpression::Variable("n".to_string()),
1708                order: SortOrder::Ascending,
1709                nulls: None,
1710            }],
1711            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1712                variable: "n".to_string(),
1713                label: Some("Person".to_string()),
1714                input: None,
1715            })),
1716        });
1717
1718        let cardinality = estimator.estimate(&sort);
1719        assert!((cardinality - 1000.0).abs() < 0.001);
1720    }
1721
1722    #[test]
1723    fn test_distinct_reduces_cardinality() {
1724        let mut estimator = CardinalityEstimator::new();
1725        estimator.add_table_stats("Person", TableStats::new(1000));
1726
1727        let distinct = LogicalOperator::Distinct(DistinctOp {
1728            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1729                variable: "n".to_string(),
1730                label: Some("Person".to_string()),
1731                input: None,
1732            })),
1733            columns: None,
1734        });
1735
1736        let cardinality = estimator.estimate(&distinct);
1737        // Distinct assumes 50% unique
1738        assert!((cardinality - 500.0).abs() < 0.001);
1739    }
1740
1741    #[test]
1742    fn test_skip_reduces_cardinality() {
1743        let mut estimator = CardinalityEstimator::new();
1744        estimator.add_table_stats("Person", TableStats::new(1000));
1745
1746        let skip = LogicalOperator::Skip(SkipOp {
1747            count: 100.into(),
1748            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749                variable: "n".to_string(),
1750                label: Some("Person".to_string()),
1751                input: None,
1752            })),
1753        });
1754
1755        let cardinality = estimator.estimate(&skip);
1756        assert!((cardinality - 900.0).abs() < 0.001);
1757    }
1758
1759    #[test]
1760    fn test_return_preserves_cardinality() {
1761        let mut estimator = CardinalityEstimator::new();
1762        estimator.add_table_stats("Person", TableStats::new(1000));
1763
1764        let ret = LogicalOperator::Return(ReturnOp {
1765            items: vec![ReturnItem {
1766                expression: LogicalExpression::Variable("n".to_string()),
1767                alias: None,
1768            }],
1769            distinct: false,
1770            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1771                variable: "n".to_string(),
1772                label: Some("Person".to_string()),
1773                input: None,
1774            })),
1775        });
1776
1777        let cardinality = estimator.estimate(&ret);
1778        assert!((cardinality - 1000.0).abs() < 0.001);
1779    }
1780
1781    #[test]
1782    fn test_empty_cardinality() {
1783        let estimator = CardinalityEstimator::new();
1784        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1785        assert!((cardinality).abs() < 0.001);
1786    }
1787
1788    #[test]
1789    fn test_table_stats_with_column() {
1790        let stats = TableStats::new(1000).with_column(
1791            "age",
1792            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1793        );
1794
1795        assert_eq!(stats.row_count, 1000);
1796        let col = stats.columns.get("age").unwrap();
1797        assert_eq!(col.distinct_count, 50);
1798        assert_eq!(col.null_count, 10);
1799        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1800        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1801    }
1802
1803    #[test]
1804    fn test_estimator_default() {
1805        let estimator = CardinalityEstimator::default();
1806        let scan = LogicalOperator::NodeScan(NodeScanOp {
1807            variable: "n".to_string(),
1808            label: None,
1809            input: None,
1810        });
1811        let cardinality = estimator.estimate(&scan);
1812        assert!((cardinality - 1000.0).abs() < 0.001);
1813    }
1814
1815    #[test]
1816    fn test_set_avg_fanout() {
1817        let mut estimator = CardinalityEstimator::new();
1818        estimator.add_table_stats("Person", TableStats::new(100));
1819        estimator.set_avg_fanout(5.0);
1820
1821        let expand = LogicalOperator::Expand(ExpandOp {
1822            from_variable: "a".to_string(),
1823            to_variable: "b".to_string(),
1824            edge_variable: None,
1825            direction: ExpandDirection::Outgoing,
1826            edge_types: vec![],
1827            min_hops: 1,
1828            max_hops: Some(1),
1829            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1830                variable: "a".to_string(),
1831                label: Some("Person".to_string()),
1832                input: None,
1833            })),
1834            path_alias: None,
1835            path_mode: PathMode::Walk,
1836        });
1837
1838        let cardinality = estimator.estimate(&expand);
1839        // With fanout 5: 100 * 5 = 500
1840        assert!((cardinality - 500.0).abs() < 0.001);
1841    }
1842
1843    #[test]
1844    fn test_multiple_group_by_keys_reduce_cardinality() {
1845        // The current implementation uses a simplified model where more group by keys
1846        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1847        // that works for most cases where group by keys are correlated.
1848        let mut estimator = CardinalityEstimator::new();
1849        estimator.add_table_stats("Person", TableStats::new(10000));
1850
1851        let single_group = LogicalOperator::Aggregate(AggregateOp {
1852            group_by: vec![LogicalExpression::Property {
1853                variable: "n".to_string(),
1854                property: "city".to_string(),
1855            }],
1856            aggregates: vec![],
1857            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1858                variable: "n".to_string(),
1859                label: Some("Person".to_string()),
1860                input: None,
1861            })),
1862            having: None,
1863        });
1864
1865        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1866            group_by: vec![
1867                LogicalExpression::Property {
1868                    variable: "n".to_string(),
1869                    property: "city".to_string(),
1870                },
1871                LogicalExpression::Property {
1872                    variable: "n".to_string(),
1873                    property: "country".to_string(),
1874                },
1875            ],
1876            aggregates: vec![],
1877            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1878                variable: "n".to_string(),
1879                label: Some("Person".to_string()),
1880                input: None,
1881            })),
1882            having: None,
1883        });
1884
1885        let single_card = estimator.estimate(&single_group);
1886        let multi_card = estimator.estimate(&multi_group);
1887
1888        // Both should reduce cardinality from input
1889        assert!(single_card < 10000.0);
1890        assert!(multi_card < 10000.0);
1891        // Both should be at least 1
1892        assert!(single_card >= 1.0);
1893        assert!(multi_card >= 1.0);
1894    }
1895
1896    // ============= Histogram Tests =============
1897
1898    #[test]
1899    fn test_histogram_build_uniform() {
1900        // Build histogram from uniformly distributed data
1901        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1902        let histogram = EquiDepthHistogram::build(&values, 10);
1903
1904        assert_eq!(histogram.num_buckets(), 10);
1905        assert_eq!(histogram.total_rows(), 100);
1906
1907        // Each bucket should have approximately 10 rows
1908        for bucket in histogram.buckets() {
1909            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1910        }
1911    }
1912
1913    #[test]
1914    fn test_histogram_build_skewed() {
1915        // Build histogram from skewed data (many small values, few large)
1916        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1917        values.extend((0..20).map(|i| 1000.0 + i as f64));
1918        let histogram = EquiDepthHistogram::build(&values, 5);
1919
1920        assert_eq!(histogram.num_buckets(), 5);
1921        assert_eq!(histogram.total_rows(), 100);
1922
1923        // Each bucket should have ~20 rows despite skewed data
1924        for bucket in histogram.buckets() {
1925            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1926        }
1927    }
1928
1929    #[test]
1930    fn test_histogram_range_selectivity_full() {
1931        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1932        let histogram = EquiDepthHistogram::build(&values, 10);
1933
1934        // Full range should have selectivity ~1.0
1935        let selectivity = histogram.range_selectivity(None, None);
1936        assert!((selectivity - 1.0).abs() < 0.01);
1937    }
1938
1939    #[test]
1940    fn test_histogram_range_selectivity_half() {
1941        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1942        let histogram = EquiDepthHistogram::build(&values, 10);
1943
1944        // Values >= 50 should be ~50% (half the data)
1945        let selectivity = histogram.range_selectivity(Some(50.0), None);
1946        assert!(selectivity > 0.4 && selectivity < 0.6);
1947    }
1948
1949    #[test]
1950    fn test_histogram_range_selectivity_quarter() {
1951        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1952        let histogram = EquiDepthHistogram::build(&values, 10);
1953
1954        // Values < 25 should be ~25%
1955        let selectivity = histogram.range_selectivity(None, Some(25.0));
1956        assert!(selectivity > 0.2 && selectivity < 0.3);
1957    }
1958
1959    #[test]
1960    fn test_histogram_equality_selectivity() {
1961        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1962        let histogram = EquiDepthHistogram::build(&values, 10);
1963
1964        // Equality on 100 distinct values should be ~1%
1965        let selectivity = histogram.equality_selectivity(50.0);
1966        assert!(selectivity > 0.005 && selectivity < 0.02);
1967    }
1968
1969    #[test]
1970    fn test_histogram_empty() {
1971        let histogram = EquiDepthHistogram::build(&[], 10);
1972
1973        assert_eq!(histogram.num_buckets(), 0);
1974        assert_eq!(histogram.total_rows(), 0);
1975
1976        // Default selectivity for empty histogram
1977        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1978        assert!((selectivity - 0.33).abs() < 0.01);
1979    }
1980
1981    #[test]
1982    fn test_histogram_bucket_overlap() {
1983        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1984
1985        // Full overlap
1986        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1987
1988        // Half overlap (lower half)
1989        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1990
1991        // Half overlap (upper half)
1992        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1993
1994        // No overlap (below)
1995        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1996
1997        // No overlap (above)
1998        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1999    }
2000
2001    #[test]
2002    fn test_column_stats_from_values() {
2003        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2004        let stats = ColumnStats::from_values(values, 4);
2005
2006        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
2007        assert!(stats.min_value.is_some());
2008        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2009        assert!(stats.max_value.is_some());
2010        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2011        assert!(stats.histogram.is_some());
2012    }
2013
2014    #[test]
2015    fn test_column_stats_with_histogram_builder() {
2016        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2017        let histogram = EquiDepthHistogram::build(&values, 10);
2018
2019        let stats = ColumnStats::new(100)
2020            .with_range(0.0, 99.0)
2021            .with_histogram(histogram);
2022
2023        assert!(stats.histogram.is_some());
2024        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2025    }
2026
2027    #[test]
2028    fn test_filter_with_histogram_stats() {
2029        let mut estimator = CardinalityEstimator::new();
2030
2031        // Create stats with histogram for age column
2032        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2033        let histogram = EquiDepthHistogram::build(&age_values, 10);
2034        let age_stats = ColumnStats::new(62)
2035            .with_range(18.0, 79.0)
2036            .with_histogram(histogram);
2037
2038        estimator.add_table_stats(
2039            "Person",
2040            TableStats::new(1000).with_column("age", age_stats),
2041        );
2042
2043        // Filter: age > 50
2044        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
2045        let filter = LogicalOperator::Filter(FilterOp {
2046            predicate: LogicalExpression::Binary {
2047                left: Box::new(LogicalExpression::Property {
2048                    variable: "n".to_string(),
2049                    property: "age".to_string(),
2050                }),
2051                op: BinaryOp::Gt,
2052                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2053            },
2054            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2055                variable: "n".to_string(),
2056                label: Some("Person".to_string()),
2057                input: None,
2058            })),
2059            pushdown_hint: None,
2060        });
2061
2062        let cardinality = estimator.estimate(&filter);
2063
2064        // With histogram, should get more accurate estimate than default 0.33
2065        // Expected: ~47.5% of 1000 = ~475
2066        assert!(cardinality > 300.0 && cardinality < 600.0);
2067    }
2068
2069    #[test]
2070    fn test_filter_equality_with_histogram() {
2071        let mut estimator = CardinalityEstimator::new();
2072
2073        // Create stats with histogram
2074        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2075        let histogram = EquiDepthHistogram::build(&values, 10);
2076        let stats = ColumnStats::new(100)
2077            .with_range(0.0, 99.0)
2078            .with_histogram(histogram);
2079
2080        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2081
2082        // Filter: value = 50
2083        let filter = LogicalOperator::Filter(FilterOp {
2084            predicate: LogicalExpression::Binary {
2085                left: Box::new(LogicalExpression::Property {
2086                    variable: "d".to_string(),
2087                    property: "value".to_string(),
2088                }),
2089                op: BinaryOp::Eq,
2090                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2091            },
2092            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2093                variable: "d".to_string(),
2094                label: Some("Data".to_string()),
2095                input: None,
2096            })),
2097            pushdown_hint: None,
2098        });
2099
2100        let cardinality = estimator.estimate(&filter);
2101
2102        // With 100 distinct values, selectivity should be ~1%
2103        // 1000 * 0.01 = 10
2104        assert!((1.0..50.0).contains(&cardinality));
2105    }
2106
2107    #[test]
2108    fn test_histogram_min_max() {
2109        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2110        let histogram = EquiDepthHistogram::build(&values, 2);
2111
2112        assert_eq!(histogram.min_value(), Some(5.0));
2113        // Max is the upper bound of the last bucket
2114        assert!(histogram.max_value().is_some());
2115    }
2116
2117    // ==================== SelectivityConfig Tests ====================
2118
2119    #[test]
2120    fn test_selectivity_config_defaults() {
2121        let config = SelectivityConfig::new();
2122        assert!((config.default - 0.1).abs() < f64::EPSILON);
2123        assert!((config.equality - 0.01).abs() < f64::EPSILON);
2124        assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2125        assert!((config.range - 0.33).abs() < f64::EPSILON);
2126        assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2127        assert!((config.membership - 0.1).abs() < f64::EPSILON);
2128        assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2129        assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2130        assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2131    }
2132
2133    #[test]
2134    fn test_custom_selectivity_config() {
2135        let config = SelectivityConfig {
2136            equality: 0.05,
2137            range: 0.25,
2138            ..SelectivityConfig::new()
2139        };
2140        let estimator = CardinalityEstimator::with_selectivity_config(config);
2141        assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2142        assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2143    }
2144
2145    #[test]
2146    fn test_custom_selectivity_affects_estimation() {
2147        // Default: equality = 0.01 → 1000 * 0.01 = 10
2148        let mut default_est = CardinalityEstimator::new();
2149        default_est.add_table_stats("Person", TableStats::new(1000));
2150
2151        let filter = LogicalOperator::Filter(FilterOp {
2152            predicate: LogicalExpression::Binary {
2153                left: Box::new(LogicalExpression::Property {
2154                    variable: "n".to_string(),
2155                    property: "name".to_string(),
2156                }),
2157                op: BinaryOp::Eq,
2158                right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2159            },
2160            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2161                variable: "n".to_string(),
2162                label: Some("Person".to_string()),
2163                input: None,
2164            })),
2165            pushdown_hint: None,
2166        });
2167
2168        let default_card = default_est.estimate(&filter);
2169
2170        // Custom: equality = 0.2 → 1000 * 0.2 = 200
2171        let config = SelectivityConfig {
2172            equality: 0.2,
2173            ..SelectivityConfig::new()
2174        };
2175        let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2176        custom_est.add_table_stats("Person", TableStats::new(1000));
2177
2178        let custom_card = custom_est.estimate(&filter);
2179
2180        assert!(custom_card > default_card);
2181        assert!((custom_card - 200.0).abs() < 1.0);
2182    }
2183
2184    #[test]
2185    fn test_custom_range_selectivity() {
2186        let config = SelectivityConfig {
2187            range: 0.5,
2188            ..SelectivityConfig::new()
2189        };
2190        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2191        estimator.add_table_stats("Person", TableStats::new(1000));
2192
2193        let filter = LogicalOperator::Filter(FilterOp {
2194            predicate: LogicalExpression::Binary {
2195                left: Box::new(LogicalExpression::Property {
2196                    variable: "n".to_string(),
2197                    property: "age".to_string(),
2198                }),
2199                op: BinaryOp::Gt,
2200                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2201            },
2202            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2203                variable: "n".to_string(),
2204                label: Some("Person".to_string()),
2205                input: None,
2206            })),
2207            pushdown_hint: None,
2208        });
2209
2210        let cardinality = estimator.estimate(&filter);
2211        // 1000 * 0.5 = 500
2212        assert!((cardinality - 500.0).abs() < 1.0);
2213    }
2214
2215    #[test]
2216    fn test_custom_distinct_fraction() {
2217        let config = SelectivityConfig {
2218            distinct_fraction: 0.8,
2219            ..SelectivityConfig::new()
2220        };
2221        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2222        estimator.add_table_stats("Person", TableStats::new(1000));
2223
2224        let distinct = LogicalOperator::Distinct(DistinctOp {
2225            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2226                variable: "n".to_string(),
2227                label: Some("Person".to_string()),
2228                input: None,
2229            })),
2230            columns: None,
2231        });
2232
2233        let cardinality = estimator.estimate(&distinct);
2234        // 1000 * 0.8 = 800
2235        assert!((cardinality - 800.0).abs() < 1.0);
2236    }
2237
2238    // ==================== EstimationLog Tests ====================
2239
2240    #[test]
2241    fn test_estimation_log_basic() {
2242        let mut log = EstimationLog::new(10.0);
2243        log.record("NodeScan(Person)", 1000.0, 1200.0);
2244        log.record("Filter(age > 30)", 100.0, 90.0);
2245
2246        assert_eq!(log.entries().len(), 2);
2247        assert!(!log.should_replan()); // 1.2x and 0.9x are within 10x threshold
2248    }
2249
2250    #[test]
2251    fn test_estimation_log_triggers_replan() {
2252        let mut log = EstimationLog::new(10.0);
2253        log.record("NodeScan(Person)", 100.0, 5000.0); // 50x underestimate
2254
2255        assert!(log.should_replan());
2256    }
2257
2258    #[test]
2259    fn test_estimation_log_overestimate_triggers_replan() {
2260        let mut log = EstimationLog::new(5.0);
2261        log.record("Filter", 1000.0, 100.0); // 10x overestimate → ratio = 0.1
2262
2263        assert!(log.should_replan()); // 0.1 < 1/5 = 0.2
2264    }
2265
2266    #[test]
2267    fn test_estimation_entry_error_ratio() {
2268        let entry = EstimationEntry {
2269            operator: "test".into(),
2270            estimated: 100.0,
2271            actual: 200.0,
2272        };
2273        assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2274
2275        let perfect = EstimationEntry {
2276            operator: "test".into(),
2277            estimated: 100.0,
2278            actual: 100.0,
2279        };
2280        assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2281
2282        let zero_est = EstimationEntry {
2283            operator: "test".into(),
2284            estimated: 0.0,
2285            actual: 0.0,
2286        };
2287        assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2288    }
2289
2290    #[test]
2291    fn test_estimation_log_max_error_ratio() {
2292        let mut log = EstimationLog::new(10.0);
2293        log.record("A", 100.0, 300.0); // 3x
2294        log.record("B", 100.0, 50.0); // 2x (normalized: 1/0.5 = 2)
2295        log.record("C", 100.0, 100.0); // 1x
2296
2297        assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2298    }
2299
2300    #[test]
2301    fn test_estimation_log_clear() {
2302        let mut log = EstimationLog::new(10.0);
2303        log.record("A", 100.0, 100.0);
2304        assert_eq!(log.entries().len(), 1);
2305
2306        log.clear();
2307        assert!(log.entries().is_empty());
2308        assert!(!log.should_replan());
2309    }
2310
2311    #[test]
2312    fn test_create_estimation_log() {
2313        let log = CardinalityEstimator::create_estimation_log();
2314        assert!(log.entries().is_empty());
2315        assert!(!log.should_replan());
2316    }
2317}