Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19    LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20    TripleComponent, TripleScanOp, UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram.
///
/// A bucket describes the value interval `[lower_bound, upper_bound)` together
/// with the number of rows and the number of distinct values that fall in it.
/// In an equi-depth histogram every bucket holds roughly the same number of
/// rows, so bucket widths vary with the data density.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Upper edge of the bucket; [`contains`](Self::contains) treats it as exclusive.
    pub upper_bound: f64,
    /// Number of rows that fall into this bucket.
    pub frequency: u64,
    /// Number of distinct values that fall into this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Builds a bucket from its bounds and counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` when `lower_bound <= value < upper_bound`.
    ///
    /// A zero-width bucket therefore contains no value at all.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Returns the fraction (`0.0..=1.0`) of this bucket's width that the
    /// range `[lower, upper]` covers.
    ///
    /// A `None` bound means the range is unbounded on that side. For a bucket
    /// whose width is zero or negative the answer is all-or-nothing: `1.0`
    /// when the range spans the bucket's bounds, `0.0` otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the requested range to the bucket.
        let clipped_lo = lower.map_or(self.lower_bound, |lo| lo.max(self.lower_bound));
        let clipped_hi = upper.map_or(self.upper_bound, |hi| hi.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Degenerate bucket: dividing by the width is meaningless.
            let fully_covered = clipped_lo <= self.lower_bound && clipped_hi >= self.upper_bound;
            return if fully_covered { 1.0 } else { 0.0 };
        }

        let covered = (clipped_hi - clipped_lo).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```no_run
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower
213                && bucket_upper <= l
214            {
215                continue;
216            }
217            if let Some(u) = upper
218                && bucket_lower >= u
219            {
220                continue;
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
/// Counts the distinct values in a sorted slice.
///
/// A value is counted as new when it differs from the most recently counted
/// value by more than `f64::EPSILON`; closer values are folded into the
/// previous one. Returns `0` for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let Some((&first, rest)) = sorted_values.split_first() else {
        return 0;
    };

    // State: (values counted so far, last value that was counted).
    let (distinct, _last_counted) =
        rest.iter()
            .fold((1u64, first), |(seen, anchor), &candidate| {
                if (candidate - anchor).abs() > f64::EPSILON {
                    (seen + 1, candidate)
                } else {
                    (seen, anchor)
                }
            });

    distinct
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
/// Tunable fallback selectivities used by the cardinality estimator.
///
/// Each field is the fraction of rows a predicate of that kind is assumed to
/// keep when no histogram or column statistics are available. Workloads with
/// known skew can override individual values.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Unknown predicates (default: 0.1).
    pub default: f64,
    /// Equality predicates without stats (default: 0.01).
    pub equality: f64,
    /// Inequality predicates (default: 0.99).
    pub inequality: f64,
    /// Range predicates without stats (default: 0.33).
    pub range: f64,
    /// String operations: STARTS WITH, ENDS WITH, CONTAINS, LIKE (default: 0.1).
    pub string_ops: f64,
    /// IN membership (default: 0.1).
    pub membership: f64,
    /// IS NULL (default: 0.05).
    pub is_null: f64,
    /// IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Fraction of rows assumed to survive a DISTINCT (default: 0.5).
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns a config populated with the standard defaults listed on each field.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    /// Same values as [`SelectivityConfig::new`].
    fn default() -> Self {
        Self {
            // Generic / pattern-style predicates.
            default: 0.1,
            string_ops: 0.1,
            membership: 0.1,
            // Comparison predicates.
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            // Null checks.
            is_null: 0.05,
            is_not_null: 0.95,
            // DISTINCT.
            distinct_fraction: 0.5,
        }
    }
}
446
/// One estimated-versus-observed cardinality sample for a plan operator.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Human-readable label for the operator (e.g., "NodeScan(Person)").
    pub operator: String,
    /// Row count the optimizer predicted.
    pub estimated: f64,
    /// Row count seen when the operator actually ran.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// * about `1.0` means the estimate was accurate,
    /// * above `1.0` means the optimizer underestimated,
    /// * below `1.0` means the optimizer overestimated.
    ///
    /// When the estimate is (numerically) zero the division is avoided: the
    /// ratio is `1.0` if the actual count is zero too, otherwise infinity.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        let is_zero = |x: f64| x.abs() < f64::EPSILON;

        match (is_zero(self.estimated), is_zero(self.actual)) {
            (false, _) => self.actual / self.estimated,
            (true, true) => 1.0,
            (true, false) => f64::INFINITY,
        }
    }
}
477
478/// Collects estimate vs actual cardinality data for query plan analysis.
479///
480/// After executing a query, call [`record()`](Self::record) for each
481/// operator with its estimated and actual cardinalities. Then use
482/// [`should_replan()`](Self::should_replan) to decide whether the plan
483/// should be re-optimized.
484#[derive(Debug, Clone, Default)]
485pub struct EstimationLog {
486    /// Recorded entries.
487    entries: Vec<EstimationEntry>,
488    /// Error ratio threshold that triggers re-planning (default: 10.0).
489    ///
490    /// If any operator's error ratio exceeds this, `should_replan()` returns true.
491    replan_threshold: f64,
492}
493
494impl EstimationLog {
495    /// Creates a new estimation log with the given re-planning threshold.
496    #[must_use]
497    pub fn new(replan_threshold: f64) -> Self {
498        Self {
499            entries: Vec::new(),
500            replan_threshold,
501        }
502    }
503
504    /// Records an estimate-vs-actual observation.
505    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
506        self.entries.push(EstimationEntry {
507            operator: operator.into(),
508            estimated,
509            actual,
510        });
511    }
512
513    /// Returns all recorded entries.
514    #[must_use]
515    pub fn entries(&self) -> &[EstimationEntry] {
516        &self.entries
517    }
518
519    /// Returns whether any operator's estimation error exceeds the threshold,
520    /// indicating the plan should be re-optimized.
521    #[must_use]
522    pub fn should_replan(&self) -> bool {
523        self.entries.iter().any(|e| {
524            let ratio = e.error_ratio();
525            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
526        })
527    }
528
529    /// Returns the maximum error ratio across all entries.
530    #[must_use]
531    pub fn max_error_ratio(&self) -> f64 {
532        self.entries
533            .iter()
534            .map(|e| {
535                let r = e.error_ratio();
536                // Normalize so both over- and under-estimation are > 1.0
537                if r < 1.0 { 1.0 / r } else { r }
538            })
539            .fold(1.0_f64, f64::max)
540    }
541
542    /// Clears all entries.
543    pub fn clear(&mut self) {
544        self.entries.clear();
545    }
546}
547
/// Cardinality estimator.
///
/// Produces row-count estimates for logical plan operators from per-label
/// statistics where they have been registered, and from the default values
/// held in the fields below otherwise.
pub struct CardinalityEstimator {
    /// Statistics for each label/table, keyed by label name.
    table_stats: HashMap<String, TableStats>,
    /// Default row count for unknown tables (the constructors start it at
    /// 1000; the statistics-based constructors may overwrite it).
    default_row_count: u64,
    /// Default selectivity for unknown predicates (copied from the
    /// `default` field of the selectivity config at construction).
    default_selectivity: f64,
    /// Average edge fanout (outgoing edges per node); starts at 10.0.
    avg_fanout: f64,
    /// Configurable selectivity defaults.
    selectivity_config: SelectivityConfig,
    /// RDF statistics for triple pattern cardinality estimation
    /// (`None` unless built via `from_rdf_statistics`).
    rdf_statistics: Option<grafeo_core::statistics::RdfStatistics>,
}
563
564impl CardinalityEstimator {
565    /// Creates a new cardinality estimator with default settings.
566    #[must_use]
567    pub fn new() -> Self {
568        let config = SelectivityConfig::new();
569        Self {
570            table_stats: HashMap::new(),
571            default_row_count: 1000,
572            default_selectivity: config.default,
573            avg_fanout: 10.0,
574            selectivity_config: config,
575            rdf_statistics: None,
576        }
577    }
578
579    /// Creates a new cardinality estimator with custom selectivity configuration.
580    #[must_use]
581    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
582        Self {
583            table_stats: HashMap::new(),
584            default_row_count: 1000,
585            default_selectivity: config.default,
586            avg_fanout: 10.0,
587            selectivity_config: config,
588            rdf_statistics: None,
589        }
590    }
591
592    /// Returns the current selectivity configuration.
593    #[must_use]
594    pub fn selectivity_config(&self) -> &SelectivityConfig {
595        &self.selectivity_config
596    }
597
598    /// Creates an estimation log with the default re-planning threshold (10x).
599    #[must_use]
600    pub fn create_estimation_log() -> EstimationLog {
601        EstimationLog::new(10.0)
602    }
603
604    /// Creates a cardinality estimator pre-populated from store statistics.
605    ///
606    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
607    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
608    /// missing statistics.
609    #[must_use]
610    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
611        let mut estimator = Self::new();
612
613        // Use total node count as default for unlabeled scans
614        if stats.total_nodes > 0 {
615            estimator.default_row_count = stats.total_nodes;
616        }
617
618        // Convert label statistics to optimizer table stats
619        for (label, label_stats) in &stats.labels {
620            let mut table_stats = TableStats::new(label_stats.node_count);
621
622            // Map property statistics (distinct count for selectivity estimation)
623            for (prop, col_stats) in &label_stats.properties {
624                let optimizer_col =
625                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
626                table_stats = table_stats.with_column(prop, optimizer_col);
627            }
628
629            estimator.add_table_stats(label, table_stats);
630        }
631
632        // Compute average fanout from edge type statistics
633        if !stats.edge_types.is_empty() {
634            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
635            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
636        } else if stats.total_nodes > 0 {
637            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
638        }
639
640        // Clamp fanout to a reasonable minimum
641        if estimator.avg_fanout < 1.0 {
642            estimator.avg_fanout = 1.0;
643        }
644
645        estimator
646    }
647
648    /// Creates a cardinality estimator from RDF statistics.
649    ///
650    /// Uses triple pattern cardinality estimates for `TripleScan` operators
651    /// and join selectivity from per-predicate statistics.
652    #[must_use]
653    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
654        let mut estimator = Self::new();
655        if rdf_stats.total_triples > 0 {
656            estimator.default_row_count = rdf_stats.total_triples;
657        }
658        estimator.rdf_statistics = Some(rdf_stats);
659        estimator
660    }
661
662    /// Adds statistics for a table/label.
663    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
664        self.table_stats.insert(name.to_string(), stats);
665    }
666
667    /// Sets the average edge fanout.
668    pub fn set_avg_fanout(&mut self, fanout: f64) {
669        self.avg_fanout = fanout;
670    }
671
672    /// Estimates the cardinality of a logical operator.
673    #[must_use]
674    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
675        match op {
676            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
677            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
678            LogicalOperator::Project(project) => self.estimate_project(project),
679            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
680            LogicalOperator::Join(join) => self.estimate_join(join),
681            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
682            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
683            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
684            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
685            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
686            LogicalOperator::Return(ret) => self.estimate(&ret.input),
687            LogicalOperator::Empty => 0.0,
688            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
689            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
690            LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
691            LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
692            LogicalOperator::TripleScan(scan) => self.estimate_triple_scan(scan),
693            _ => self.default_row_count as f64,
694        }
695    }
696
697    /// Estimates node scan cardinality.
698    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
699        if let Some(label) = &scan.label
700            && let Some(stats) = self.table_stats.get(label)
701        {
702            return stats.row_count as f64;
703        }
704        // No label filter - scan all nodes
705        self.default_row_count as f64
706    }
707
708    /// Estimates triple scan cardinality using RDF statistics.
709    ///
710    /// If RDF statistics are available, uses the pattern binding (which positions
711    /// are bound vs variable) to produce accurate estimates. Otherwise falls back
712    /// to the default row count.
713    fn estimate_triple_scan(&self, scan: &TripleScanOp) -> f64 {
714        // If there's an input, the triple scan is chained: multiply input cardinality
715        // by the per-row expansion factor.
716        let base = if let Some(ref input) = scan.input {
717            self.estimate(input)
718        } else {
719            1.0
720        };
721
722        let Some(rdf_stats) = &self.rdf_statistics else {
723            return if scan.input.is_some() {
724                base * self.default_row_count as f64
725            } else {
726                self.default_row_count as f64
727            };
728        };
729
730        let subject_bound = matches!(
731            scan.subject,
732            TripleComponent::Iri(_) | TripleComponent::Literal(_)
733        );
734        let object_bound = matches!(
735            scan.object,
736            TripleComponent::Iri(_) | TripleComponent::Literal(_)
737        );
738        let predicate_iri = match &scan.predicate {
739            TripleComponent::Iri(iri) => Some(iri.as_str()),
740            _ => None,
741        };
742
743        let pattern_card = rdf_stats.estimate_triple_pattern_cardinality(
744            subject_bound,
745            predicate_iri,
746            object_bound,
747        );
748
749        if scan.input.is_some() {
750            // Chained scan: each input row expands by the pattern's selectivity
751            let selectivity = if rdf_stats.total_triples > 0 {
752                pattern_card / rdf_stats.total_triples as f64
753            } else {
754                1.0
755            };
756            (base * pattern_card * selectivity).max(1.0)
757        } else {
758            pattern_card.max(1.0)
759        }
760    }
761
762    /// Estimates filter cardinality.
763    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
764        let input_cardinality = self.estimate(&filter.input);
765        let selectivity = self.estimate_selectivity(&filter.predicate);
766        (input_cardinality * selectivity).max(1.0)
767    }
768
769    /// Estimates projection cardinality (same as input).
770    fn estimate_project(&self, project: &ProjectOp) -> f64 {
771        self.estimate(&project.input)
772    }
773
774    /// Estimates expand cardinality.
775    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
776        let input_cardinality = self.estimate(&expand.input);
777
778        // Apply fanout based on edge type
779        let fanout = if !expand.edge_types.is_empty() {
780            // Specific edge type(s) typically have lower fanout
781            self.avg_fanout * 0.5
782        } else {
783            self.avg_fanout
784        };
785
786        // Handle variable-length paths
787        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
788            let min = expand.min_hops as f64;
789            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
790            // Geometric series approximation
791            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
792        } else {
793            fanout
794        };
795
796        (input_cardinality * path_multiplier).max(1.0)
797    }
798
799    /// Estimates join cardinality.
800    fn estimate_join(&self, join: &JoinOp) -> f64 {
801        let left_card = self.estimate(&join.left);
802        let right_card = self.estimate(&join.right);
803
804        match join.join_type {
805            JoinType::Cross => left_card * right_card,
806            JoinType::Inner => {
807                // Assume join selectivity based on conditions
808                let selectivity = if join.conditions.is_empty() {
809                    1.0 // Cross join
810                } else {
811                    // Estimate based on number of conditions
812                    0.1_f64.powi(join.conditions.len() as i32)
813                };
814                (left_card * right_card * selectivity).max(1.0)
815            }
816            JoinType::Left => {
817                // Left join returns at least all left rows
818                let inner_card = self.estimate_join(&JoinOp {
819                    left: join.left.clone(),
820                    right: join.right.clone(),
821                    join_type: JoinType::Inner,
822                    conditions: join.conditions.clone(),
823                });
824                inner_card.max(left_card)
825            }
826            JoinType::Right => {
827                // Right join returns at least all right rows
828                let inner_card = self.estimate_join(&JoinOp {
829                    left: join.left.clone(),
830                    right: join.right.clone(),
831                    join_type: JoinType::Inner,
832                    conditions: join.conditions.clone(),
833                });
834                inner_card.max(right_card)
835            }
836            JoinType::Full => {
837                // Full join returns at least max(left, right)
838                let inner_card = self.estimate_join(&JoinOp {
839                    left: join.left.clone(),
840                    right: join.right.clone(),
841                    join_type: JoinType::Inner,
842                    conditions: join.conditions.clone(),
843                });
844                inner_card.max(left_card.max(right_card))
845            }
846            JoinType::Semi => {
847                // Semi join returns at most left cardinality
848                (left_card * self.default_selectivity).max(1.0)
849            }
850            JoinType::Anti => {
851                // Anti join returns at most left cardinality
852                (left_card * (1.0 - self.default_selectivity)).max(1.0)
853            }
854        }
855    }
856
857    /// Estimates left join cardinality (OPTIONAL MATCH).
858    ///
859    /// A left outer join preserves all left rows, so the output is at least
860    /// `left_cardinality`. When the right side matches, the output may be
861    /// larger (one left row can match multiple right rows).
862    fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
863        let left_card = self.estimate(&lj.left);
864        let right_card = self.estimate(&lj.right);
865
866        // Estimate as inner join cardinality, but guaranteed at least left_card
867        let inner_estimate = left_card * right_card * self.default_selectivity;
868        inner_estimate.max(left_card).max(1.0)
869    }
870
871    /// Estimates aggregation cardinality.
872    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
873        let input_cardinality = self.estimate(&agg.input);
874
875        if agg.group_by.is_empty() {
876            // Global aggregation - single row
877            1.0
878        } else {
879            // Group by - estimate distinct groups
880            // Assume each group key reduces cardinality by 10
881            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
882            (input_cardinality / group_reduction).max(1.0)
883        }
884    }
885
886    /// Estimates sort cardinality (same as input).
887    fn estimate_sort(&self, sort: &SortOp) -> f64 {
888        self.estimate(&sort.input)
889    }
890
891    /// Estimates distinct cardinality.
892    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
893        let input_cardinality = self.estimate(&distinct.input);
894        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
895    }
896
897    /// Estimates limit cardinality.
898    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
899        let input_cardinality = self.estimate(&limit.input);
900        limit.count.estimate().min(input_cardinality)
901    }
902
903    /// Estimates skip cardinality.
904    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
905        let input_cardinality = self.estimate(&skip.input);
906        (input_cardinality - skip.count.estimate()).max(0.0)
907    }
908
909    /// Estimates vector scan cardinality.
910    ///
911    /// Vector scan returns at most k results (the k nearest neighbors).
912    /// With similarity/distance filters, it may return fewer.
913    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
914        let base_k = scan.k as f64;
915
916        // Apply filter selectivity if thresholds are specified
917        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
918            // Assume 70% of results pass threshold filters
919            0.7
920        } else {
921            1.0
922        };
923
924        (base_k * selectivity).max(1.0)
925    }
926
927    /// Estimates vector join cardinality.
928    ///
929    /// Vector join produces up to k results per input row.
930    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
931        let input_cardinality = self.estimate(&join.input);
932        let k = join.k as f64;
933
934        // Apply filter selectivity if thresholds are specified
935        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
936            0.7
937        } else {
938            1.0
939        };
940
941        (input_cardinality * k * selectivity).max(1.0)
942    }
943
944    /// Estimates multi-way join cardinality using the AGM bound heuristic.
945    ///
946    /// For a cyclic join of N relations, the AGM (Atserias-Grohe-Marx) bound
947    /// gives min(cardinality)^(N/2) as a worst-case output size estimate.
948    fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
949        if mwj.inputs.is_empty() {
950            return 0.0;
951        }
952        let cardinalities: Vec<f64> = mwj
953            .inputs
954            .iter()
955            .map(|input| self.estimate(input))
956            .collect();
957        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
958        let n = cardinalities.len() as f64;
959        // AGM bound: min(cardinality)^(n/2)
960        (min_card.powf(n / 2.0)).max(1.0)
961    }
962
963    /// Estimates the selectivity of a predicate (0.0 to 1.0).
964    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
965        match expr {
966            LogicalExpression::Binary { left, op, right } => {
967                self.estimate_binary_selectivity(left, *op, right)
968            }
969            LogicalExpression::Unary { op, operand } => {
970                self.estimate_unary_selectivity(*op, operand)
971            }
972            LogicalExpression::Literal(value) => {
973                // Boolean literal
974                if let grafeo_common::types::Value::Bool(b) = value {
975                    if *b { 1.0 } else { 0.0 }
976                } else {
977                    self.default_selectivity
978                }
979            }
980            _ => self.default_selectivity,
981        }
982    }
983
984    /// Estimates binary expression selectivity.
985    fn estimate_binary_selectivity(
986        &self,
987        left: &LogicalExpression,
988        op: BinaryOp,
989        right: &LogicalExpression,
990    ) -> f64 {
991        match op {
992            // Equality - try histogram-based estimation
993            BinaryOp::Eq => {
994                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
995                    return selectivity;
996                }
997                self.selectivity_config.equality
998            }
999            // Inequality is very unselective
1000            BinaryOp::Ne => self.selectivity_config.inequality,
1001            // Range predicates - use histogram if available
1002            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
1003                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
1004                    return selectivity;
1005                }
1006                self.selectivity_config.range
1007            }
1008            // Logical operators - recursively estimate sub-expressions
1009            BinaryOp::And => {
1010                let left_sel = self.estimate_selectivity(left);
1011                let right_sel = self.estimate_selectivity(right);
1012                // AND reduces selectivity (multiply assuming independence)
1013                left_sel * right_sel
1014            }
1015            BinaryOp::Or => {
1016                let left_sel = self.estimate_selectivity(left);
1017                let right_sel = self.estimate_selectivity(right);
1018                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
1019                // Assuming independence: P(A ∩ B) = P(A) * P(B)
1020                (left_sel + right_sel - left_sel * right_sel).min(1.0)
1021            }
1022            // String operations
1023            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
1024                self.selectivity_config.string_ops
1025            }
1026            // Collection membership
1027            BinaryOp::In => self.selectivity_config.membership,
1028            // Other operations
1029            _ => self.default_selectivity,
1030        }
1031    }
1032
1033    /// Tries to estimate equality selectivity using histograms.
1034    fn try_equality_selectivity(
1035        &self,
1036        left: &LogicalExpression,
1037        right: &LogicalExpression,
1038    ) -> Option<f64> {
1039        // Extract property access and literal value
1040        let (label, column, value) = self.extract_column_and_value(left, right)?;
1041
1042        // Get column stats with histogram
1043        let stats = self.get_column_stats(&label, &column)?;
1044
1045        // Try histogram-based estimation
1046        if let Some(ref histogram) = stats.histogram {
1047            return Some(histogram.equality_selectivity(value));
1048        }
1049
1050        // Fall back to distinct count estimation
1051        if stats.distinct_count > 0 {
1052            return Some(1.0 / stats.distinct_count as f64);
1053        }
1054
1055        None
1056    }
1057
1058    /// Tries to estimate range selectivity using histograms.
1059    fn try_range_selectivity(
1060        &self,
1061        left: &LogicalExpression,
1062        op: BinaryOp,
1063        right: &LogicalExpression,
1064    ) -> Option<f64> {
1065        // Extract property access and literal value
1066        let (label, column, value) = self.extract_column_and_value(left, right)?;
1067
1068        // Get column stats
1069        let stats = self.get_column_stats(&label, &column)?;
1070
1071        // Determine the range based on operator
1072        let (lower, upper) = match op {
1073            BinaryOp::Lt => (None, Some(value)),
1074            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1075            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1076            BinaryOp::Ge => (Some(value), None),
1077            _ => return None,
1078        };
1079
1080        // Try histogram-based estimation first
1081        if let Some(ref histogram) = stats.histogram {
1082            return Some(histogram.range_selectivity(lower, upper));
1083        }
1084
1085        // Fall back to min/max range estimation
1086        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1087            let range = max - min;
1088            if range <= 0.0 {
1089                return Some(1.0);
1090            }
1091
1092            let effective_lower = lower.unwrap_or(min).max(min);
1093            let effective_upper = upper.unwrap_or(max).min(max);
1094            let overlap = (effective_upper - effective_lower).max(0.0);
1095            return Some((overlap / range).clamp(0.0, 1.0));
1096        }
1097
1098        None
1099    }
1100
1101    /// Extracts column information and literal value from a comparison.
1102    ///
1103    /// Returns (label, column_name, numeric_value) if the expression is
1104    /// a comparison between a property access and a numeric literal.
1105    fn extract_column_and_value(
1106        &self,
1107        left: &LogicalExpression,
1108        right: &LogicalExpression,
1109    ) -> Option<(String, String, f64)> {
1110        // Try left as property, right as literal
1111        if let Some(result) = self.try_extract_property_literal(left, right) {
1112            return Some(result);
1113        }
1114
1115        // Try right as property, left as literal
1116        self.try_extract_property_literal(right, left)
1117    }
1118
1119    /// Tries to extract property and literal from a specific ordering.
1120    fn try_extract_property_literal(
1121        &self,
1122        property_expr: &LogicalExpression,
1123        literal_expr: &LogicalExpression,
1124    ) -> Option<(String, String, f64)> {
1125        // Extract property access
1126        let (variable, property) = match property_expr {
1127            LogicalExpression::Property { variable, property } => {
1128                (variable.clone(), property.clone())
1129            }
1130            _ => return None,
1131        };
1132
1133        // Extract numeric literal
1134        let value = match literal_expr {
1135            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1136            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1137            _ => return None,
1138        };
1139
1140        // Try to find a label for this variable from table stats
1141        // Use the variable name as a heuristic label lookup
1142        // In practice, the optimizer would track which labels variables are bound to
1143        for label in self.table_stats.keys() {
1144            if let Some(stats) = self.table_stats.get(label)
1145                && stats.columns.contains_key(&property)
1146            {
1147                return Some((label.clone(), property, value));
1148            }
1149        }
1150
1151        // If no stats found but we have the property, return with variable as label
1152        Some((variable, property, value))
1153    }
1154
1155    /// Estimates unary expression selectivity.
1156    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1157        match op {
1158            UnaryOp::Not => 1.0 - self.default_selectivity,
1159            UnaryOp::IsNull => self.selectivity_config.is_null,
1160            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1161            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1162        }
1163    }
1164
1165    /// Gets statistics for a column.
1166    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1167        self.table_stats.get(label)?.columns.get(column)
1168    }
1169}
1170
1171impl Default for CardinalityEstimator {
1172    fn default() -> Self {
1173        Self::new()
1174    }
1175}
1176
1177#[cfg(test)]
1178mod tests {
1179    use super::*;
1180    use crate::query::plan::{
1181        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1182        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1183    };
1184    use grafeo_common::types::Value;
1185
1186    #[test]
1187    fn test_node_scan_with_stats() {
1188        let mut estimator = CardinalityEstimator::new();
1189        estimator.add_table_stats("Person", TableStats::new(5000));
1190
1191        let scan = LogicalOperator::NodeScan(NodeScanOp {
1192            variable: "n".to_string(),
1193            label: Some("Person".to_string()),
1194            input: None,
1195        });
1196
1197        let cardinality = estimator.estimate(&scan);
1198        assert!((cardinality - 5000.0).abs() < 0.001);
1199    }
1200
1201    #[test]
1202    fn test_filter_reduces_cardinality() {
1203        let mut estimator = CardinalityEstimator::new();
1204        estimator.add_table_stats("Person", TableStats::new(1000));
1205
1206        let filter = LogicalOperator::Filter(FilterOp {
1207            predicate: LogicalExpression::Binary {
1208                left: Box::new(LogicalExpression::Property {
1209                    variable: "n".to_string(),
1210                    property: "age".to_string(),
1211                }),
1212                op: BinaryOp::Eq,
1213                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1214            },
1215            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1216                variable: "n".to_string(),
1217                label: Some("Person".to_string()),
1218                input: None,
1219            })),
1220            pushdown_hint: None,
1221        });
1222
1223        let cardinality = estimator.estimate(&filter);
1224        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1225        assert!(cardinality < 1000.0);
1226        assert!(cardinality >= 1.0);
1227    }
1228
1229    #[test]
1230    fn test_join_cardinality() {
1231        let mut estimator = CardinalityEstimator::new();
1232        estimator.add_table_stats("Person", TableStats::new(1000));
1233        estimator.add_table_stats("Company", TableStats::new(100));
1234
1235        let join = LogicalOperator::Join(JoinOp {
1236            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1237                variable: "p".to_string(),
1238                label: Some("Person".to_string()),
1239                input: None,
1240            })),
1241            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1242                variable: "c".to_string(),
1243                label: Some("Company".to_string()),
1244                input: None,
1245            })),
1246            join_type: JoinType::Inner,
1247            conditions: vec![JoinCondition {
1248                left: LogicalExpression::Property {
1249                    variable: "p".to_string(),
1250                    property: "company_id".to_string(),
1251                },
1252                right: LogicalExpression::Property {
1253                    variable: "c".to_string(),
1254                    property: "id".to_string(),
1255                },
1256            }],
1257        });
1258
1259        let cardinality = estimator.estimate(&join);
1260        // Should be less than cross product
1261        assert!(cardinality < 1000.0 * 100.0);
1262    }
1263
1264    #[test]
1265    fn test_limit_caps_cardinality() {
1266        let mut estimator = CardinalityEstimator::new();
1267        estimator.add_table_stats("Person", TableStats::new(1000));
1268
1269        let limit = LogicalOperator::Limit(LimitOp {
1270            count: 10.into(),
1271            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1272                variable: "n".to_string(),
1273                label: Some("Person".to_string()),
1274                input: None,
1275            })),
1276        });
1277
1278        let cardinality = estimator.estimate(&limit);
1279        assert!((cardinality - 10.0).abs() < 0.001);
1280    }
1281
1282    #[test]
1283    fn test_aggregate_reduces_cardinality() {
1284        let mut estimator = CardinalityEstimator::new();
1285        estimator.add_table_stats("Person", TableStats::new(1000));
1286
1287        // Global aggregation
1288        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1289            group_by: vec![],
1290            aggregates: vec![],
1291            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1292                variable: "n".to_string(),
1293                label: Some("Person".to_string()),
1294                input: None,
1295            })),
1296            having: None,
1297        });
1298
1299        let cardinality = estimator.estimate(&global_agg);
1300        assert!((cardinality - 1.0).abs() < 0.001);
1301
1302        // Group by aggregation
1303        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1304            group_by: vec![LogicalExpression::Property {
1305                variable: "n".to_string(),
1306                property: "city".to_string(),
1307            }],
1308            aggregates: vec![],
1309            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1310                variable: "n".to_string(),
1311                label: Some("Person".to_string()),
1312                input: None,
1313            })),
1314            having: None,
1315        });
1316
1317        let cardinality = estimator.estimate(&group_agg);
1318        // Should be less than input
1319        assert!(cardinality < 1000.0);
1320    }
1321
1322    #[test]
1323    fn test_node_scan_without_stats() {
1324        let estimator = CardinalityEstimator::new();
1325
1326        let scan = LogicalOperator::NodeScan(NodeScanOp {
1327            variable: "n".to_string(),
1328            label: Some("Unknown".to_string()),
1329            input: None,
1330        });
1331
1332        let cardinality = estimator.estimate(&scan);
1333        // Should return default (1000)
1334        assert!((cardinality - 1000.0).abs() < 0.001);
1335    }
1336
1337    #[test]
1338    fn test_node_scan_no_label() {
1339        let estimator = CardinalityEstimator::new();
1340
1341        let scan = LogicalOperator::NodeScan(NodeScanOp {
1342            variable: "n".to_string(),
1343            label: None,
1344            input: None,
1345        });
1346
1347        let cardinality = estimator.estimate(&scan);
1348        // Should scan all nodes (default)
1349        assert!((cardinality - 1000.0).abs() < 0.001);
1350    }
1351
1352    #[test]
1353    fn test_filter_inequality_selectivity() {
1354        let mut estimator = CardinalityEstimator::new();
1355        estimator.add_table_stats("Person", TableStats::new(1000));
1356
1357        let filter = LogicalOperator::Filter(FilterOp {
1358            predicate: LogicalExpression::Binary {
1359                left: Box::new(LogicalExpression::Property {
1360                    variable: "n".to_string(),
1361                    property: "age".to_string(),
1362                }),
1363                op: BinaryOp::Ne,
1364                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1365            },
1366            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1367                variable: "n".to_string(),
1368                label: Some("Person".to_string()),
1369                input: None,
1370            })),
1371            pushdown_hint: None,
1372        });
1373
1374        let cardinality = estimator.estimate(&filter);
1375        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1376        assert!(cardinality > 900.0);
1377    }
1378
1379    #[test]
1380    fn test_filter_range_selectivity() {
1381        let mut estimator = CardinalityEstimator::new();
1382        estimator.add_table_stats("Person", TableStats::new(1000));
1383
1384        let filter = LogicalOperator::Filter(FilterOp {
1385            predicate: LogicalExpression::Binary {
1386                left: Box::new(LogicalExpression::Property {
1387                    variable: "n".to_string(),
1388                    property: "age".to_string(),
1389                }),
1390                op: BinaryOp::Gt,
1391                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1392            },
1393            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1394                variable: "n".to_string(),
1395                label: Some("Person".to_string()),
1396                input: None,
1397            })),
1398            pushdown_hint: None,
1399        });
1400
1401        let cardinality = estimator.estimate(&filter);
1402        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1403        assert!(cardinality < 500.0);
1404        assert!(cardinality > 100.0);
1405    }
1406
1407    #[test]
1408    fn test_filter_and_selectivity() {
1409        let mut estimator = CardinalityEstimator::new();
1410        estimator.add_table_stats("Person", TableStats::new(1000));
1411
1412        // Test AND with two equality predicates
1413        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1414        let filter = LogicalOperator::Filter(FilterOp {
1415            predicate: LogicalExpression::Binary {
1416                left: Box::new(LogicalExpression::Binary {
1417                    left: Box::new(LogicalExpression::Property {
1418                        variable: "n".to_string(),
1419                        property: "city".to_string(),
1420                    }),
1421                    op: BinaryOp::Eq,
1422                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1423                }),
1424                op: BinaryOp::And,
1425                right: Box::new(LogicalExpression::Binary {
1426                    left: Box::new(LogicalExpression::Property {
1427                        variable: "n".to_string(),
1428                        property: "age".to_string(),
1429                    }),
1430                    op: BinaryOp::Eq,
1431                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1432                }),
1433            },
1434            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1435                variable: "n".to_string(),
1436                label: Some("Person".to_string()),
1437                input: None,
1438            })),
1439            pushdown_hint: None,
1440        });
1441
1442        let cardinality = estimator.estimate(&filter);
1443        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1444        // 1000 * 0.0001 = 0.1, min is 1.0
1445        assert!(cardinality < 100.0);
1446        assert!(cardinality >= 1.0);
1447    }
1448
1449    #[test]
1450    fn test_filter_or_selectivity() {
1451        let mut estimator = CardinalityEstimator::new();
1452        estimator.add_table_stats("Person", TableStats::new(1000));
1453
1454        // Test OR with two equality predicates
1455        // Each equality has selectivity 0.01
1456        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1457        let filter = LogicalOperator::Filter(FilterOp {
1458            predicate: LogicalExpression::Binary {
1459                left: Box::new(LogicalExpression::Binary {
1460                    left: Box::new(LogicalExpression::Property {
1461                        variable: "n".to_string(),
1462                        property: "city".to_string(),
1463                    }),
1464                    op: BinaryOp::Eq,
1465                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1466                }),
1467                op: BinaryOp::Or,
1468                right: Box::new(LogicalExpression::Binary {
1469                    left: Box::new(LogicalExpression::Property {
1470                        variable: "n".to_string(),
1471                        property: "city".to_string(),
1472                    }),
1473                    op: BinaryOp::Eq,
1474                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1475                }),
1476            },
1477            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1478                variable: "n".to_string(),
1479                label: Some("Person".to_string()),
1480                input: None,
1481            })),
1482            pushdown_hint: None,
1483        });
1484
1485        let cardinality = estimator.estimate(&filter);
1486        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1487        assert!(cardinality < 100.0);
1488        assert!(cardinality >= 1.0);
1489    }
1490
1491    #[test]
1492    fn test_filter_literal_true() {
1493        let mut estimator = CardinalityEstimator::new();
1494        estimator.add_table_stats("Person", TableStats::new(1000));
1495
1496        let filter = LogicalOperator::Filter(FilterOp {
1497            predicate: LogicalExpression::Literal(Value::Bool(true)),
1498            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1499                variable: "n".to_string(),
1500                label: Some("Person".to_string()),
1501                input: None,
1502            })),
1503            pushdown_hint: None,
1504        });
1505
1506        let cardinality = estimator.estimate(&filter);
1507        // Literal true has selectivity 1.0
1508        assert!((cardinality - 1000.0).abs() < 0.001);
1509    }
1510
1511    #[test]
1512    fn test_filter_literal_false() {
1513        let mut estimator = CardinalityEstimator::new();
1514        estimator.add_table_stats("Person", TableStats::new(1000));
1515
1516        let filter = LogicalOperator::Filter(FilterOp {
1517            predicate: LogicalExpression::Literal(Value::Bool(false)),
1518            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1519                variable: "n".to_string(),
1520                label: Some("Person".to_string()),
1521                input: None,
1522            })),
1523            pushdown_hint: None,
1524        });
1525
1526        let cardinality = estimator.estimate(&filter);
1527        // Literal false has selectivity 0.0, but min is 1.0
1528        assert!((cardinality - 1.0).abs() < 0.001);
1529    }
1530
1531    #[test]
1532    fn test_unary_not_selectivity() {
1533        let mut estimator = CardinalityEstimator::new();
1534        estimator.add_table_stats("Person", TableStats::new(1000));
1535
1536        let filter = LogicalOperator::Filter(FilterOp {
1537            predicate: LogicalExpression::Unary {
1538                op: UnaryOp::Not,
1539                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1540            },
1541            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1542                variable: "n".to_string(),
1543                label: Some("Person".to_string()),
1544                input: None,
1545            })),
1546            pushdown_hint: None,
1547        });
1548
1549        let cardinality = estimator.estimate(&filter);
1550        // NOT inverts selectivity
1551        assert!(cardinality < 1000.0);
1552    }
1553
1554    #[test]
1555    fn test_unary_is_null_selectivity() {
1556        let mut estimator = CardinalityEstimator::new();
1557        estimator.add_table_stats("Person", TableStats::new(1000));
1558
1559        let filter = LogicalOperator::Filter(FilterOp {
1560            predicate: LogicalExpression::Unary {
1561                op: UnaryOp::IsNull,
1562                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1563            },
1564            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1565                variable: "n".to_string(),
1566                label: Some("Person".to_string()),
1567                input: None,
1568            })),
1569            pushdown_hint: None,
1570        });
1571
1572        let cardinality = estimator.estimate(&filter);
1573        // IS NULL has selectivity 0.05
1574        assert!(cardinality < 100.0);
1575    }
1576
1577    #[test]
1578    fn test_expand_cardinality() {
1579        let mut estimator = CardinalityEstimator::new();
1580        estimator.add_table_stats("Person", TableStats::new(100));
1581
1582        let expand = LogicalOperator::Expand(ExpandOp {
1583            from_variable: "a".to_string(),
1584            to_variable: "b".to_string(),
1585            edge_variable: None,
1586            direction: ExpandDirection::Outgoing,
1587            edge_types: vec![],
1588            min_hops: 1,
1589            max_hops: Some(1),
1590            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1591                variable: "a".to_string(),
1592                label: Some("Person".to_string()),
1593                input: None,
1594            })),
1595            path_alias: None,
1596            path_mode: PathMode::Walk,
1597        });
1598
1599        let cardinality = estimator.estimate(&expand);
1600        // Expand multiplies by fanout (10)
1601        assert!(cardinality > 100.0);
1602    }
1603
1604    #[test]
1605    fn test_expand_with_edge_type_filter() {
1606        let mut estimator = CardinalityEstimator::new();
1607        estimator.add_table_stats("Person", TableStats::new(100));
1608
1609        let expand = LogicalOperator::Expand(ExpandOp {
1610            from_variable: "a".to_string(),
1611            to_variable: "b".to_string(),
1612            edge_variable: None,
1613            direction: ExpandDirection::Outgoing,
1614            edge_types: vec!["KNOWS".to_string()],
1615            min_hops: 1,
1616            max_hops: Some(1),
1617            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1618                variable: "a".to_string(),
1619                label: Some("Person".to_string()),
1620                input: None,
1621            })),
1622            path_alias: None,
1623            path_mode: PathMode::Walk,
1624        });
1625
1626        let cardinality = estimator.estimate(&expand);
1627        // With edge type, fanout is reduced by half
1628        assert!(cardinality > 100.0);
1629    }
1630
1631    #[test]
1632    fn test_expand_variable_length() {
1633        let mut estimator = CardinalityEstimator::new();
1634        estimator.add_table_stats("Person", TableStats::new(100));
1635
1636        let expand = LogicalOperator::Expand(ExpandOp {
1637            from_variable: "a".to_string(),
1638            to_variable: "b".to_string(),
1639            edge_variable: None,
1640            direction: ExpandDirection::Outgoing,
1641            edge_types: vec![],
1642            min_hops: 1,
1643            max_hops: Some(3),
1644            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1645                variable: "a".to_string(),
1646                label: Some("Person".to_string()),
1647                input: None,
1648            })),
1649            path_alias: None,
1650            path_mode: PathMode::Walk,
1651        });
1652
1653        let cardinality = estimator.estimate(&expand);
1654        // Variable length path has much higher cardinality
1655        assert!(cardinality > 500.0);
1656    }
1657
/// A cross join without conditions is estimated as the product of both input sizes.
#[test]
fn test_join_cross_product() {
    // Small local factory so both sides of the join are built the same way.
    let scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Cross,
        conditions: Vec::new(),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.add_table_stats("Company", TableStats::new(50));

    // 100 Person rows x 50 Company rows = 5000.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 5000.0).abs() < 0.001);
}
1683
/// A left outer join must never be estimated below the size of its left input.
#[test]
fn test_join_left_outer() {
    let scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    // Single join condition relating variable `p` to variable `c`.
    let p_to_c = JoinCondition {
        left: LogicalExpression::Variable(String::from("p")),
        right: LogicalExpression::Variable(String::from("c")),
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Left,
        conditions: vec![p_to_c],
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(10));

    // Every one of the 1000 left rows survives a left join.
    let estimated_rows = est.estimate(&plan);
    assert!(estimated_rows >= 1000.0);
}
1712
/// A semi join can only keep left rows, so its estimate is capped by the left input size.
#[test]
fn test_join_semi() {
    let scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_owned(),
            label: Some(label.to_owned()),
            input: None,
        })
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Semi,
        conditions: Vec::new(),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    // Upper bound: the 1000 rows of the left side.
    let estimated_rows = est.estimate(&plan);
    assert!(estimated_rows <= 1000.0);
}
1738
/// An anti join estimate stays within `[1, left cardinality]`.
#[test]
fn test_join_anti() {
    let scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_owned(),
            label: Some(label.to_owned()),
            input: None,
        })
    };
    let plan = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("p", "Person")),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Anti,
        conditions: Vec::new(),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));
    est.add_table_stats("Company", TableStats::new(100));

    let estimated_rows = est.estimate(&plan);
    // Never more than the 1000 left rows...
    assert!(estimated_rows <= 1000.0);
    // ...and never less than a single row.
    assert!(estimated_rows >= 1.0);
}
1765
/// Projecting a column does not change the number of estimated rows.
#[test]
fn test_project_preserves_cardinality() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let project_n = Projection {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Project(ProjectOp {
        projections: vec![project_n],
        input: Box::new(person_scan),
        pass_through_input: false,
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    // Same 1000 rows in, same 1000 rows out.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 1000.0).abs() < 0.001);
}
1787
/// Sorting reorders rows but leaves the estimated row count untouched.
#[test]
fn test_sort_preserves_cardinality() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let ascending_by_n = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
        nulls: None,
    };
    let plan = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_by_n],
        input: Box::new(person_scan),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 1000.0).abs() < 0.001);
}
1809
/// With the default configuration, `Distinct` over 1000 rows is estimated at 500 rows.
#[test]
fn test_distinct_reduces_cardinality() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    // The expected value corresponds to half of the input rows being unique.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 500.0).abs() < 0.001);
}
1828
/// Skipping 100 of 1000 rows leaves an estimate of 900 rows.
#[test]
fn test_skip_reduces_cardinality() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Skip(SkipOp {
        count: 100.into(),
        input: Box::new(person_scan),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    // 1000 - 100 = 900.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 900.0).abs() < 0.001);
}
1846
/// A non-distinct `Return` passes its input cardinality through unchanged.
#[test]
fn test_return_preserves_cardinality() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let return_n = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![return_n],
        distinct: false,
        input: Box::new(person_scan),
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(1000));

    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 1000.0).abs() < 0.001);
}
1868
/// The `Empty` operator is estimated to produce (approximately) zero rows.
#[test]
fn test_empty_cardinality() {
    let est = CardinalityEstimator::new();
    let empty_plan = LogicalOperator::Empty;

    let estimated_rows = est.estimate(&empty_plan);
    assert!(estimated_rows.abs() < 0.001);
}
1875
/// Builder methods on `TableStats` / `ColumnStats` store exactly the values they were given.
#[test]
fn test_table_stats_with_column() {
    let age_stats = ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0);
    let table = TableStats::new(1000).with_column("age", age_stats);

    assert_eq!(table.row_count, 1000);

    // The column must be retrievable under the name it was registered with.
    let age = table.columns.get("age").unwrap();
    assert_eq!(age.distinct_count, 50);
    assert_eq!(age.null_count, 10);

    let lowest = age.min_value.unwrap();
    assert!(lowest.abs() < 0.001);
    let highest = age.max_value.unwrap();
    assert!((highest - 100.0).abs() < 0.001);
}
1890
/// A default estimator with no statistics estimates an unlabeled node scan at 1000 rows.
#[test]
fn test_estimator_default() {
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });

    let est = CardinalityEstimator::default();
    let estimated_rows = est.estimate(&unlabeled_scan);
    assert!((estimated_rows - 1000.0).abs() < 0.001);
}
1902
/// A custom average fanout scales a single-hop, untyped expand linearly.
#[test]
fn test_set_avg_fanout() {
    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Expand(ExpandOp {
        from_variable: String::from("a"),
        to_variable: String::from("b"),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: Vec::new(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(100));
    est.set_avg_fanout(5.0);

    // 100 source rows x fanout 5 = 500.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 500.0).abs() < 0.001);
}
1930
/// Grouping by one or by two keys must shrink the estimate below the input size
/// while never dropping under a single row.
#[test]
fn test_multiple_group_by_keys_reduce_cardinality() {
    // NOTE(review): the estimator presumably uses a simplified model keyed on the number
    // of group-by expressions; this test only pins the bounds asserted below.

    // Builds `Aggregate(group_by = n.<prop>..., input = NodeScan(n:Person))`.
    let aggregate_by = |properties: &[&str]| {
        LogicalOperator::Aggregate(AggregateOp {
            group_by: properties
                .iter()
                .map(|name| LogicalExpression::Property {
                    variable: String::from("n"),
                    property: (*name).to_string(),
                })
                .collect(),
            aggregates: Vec::new(),
            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                variable: String::from("n"),
                label: Some(String::from("Person")),
                input: None,
            })),
            having: None,
        })
    };
    let one_key_plan = aggregate_by(&["city"]);
    let two_key_plan = aggregate_by(&["city", "country"]);

    let mut est = CardinalityEstimator::new();
    est.add_table_stats("Person", TableStats::new(10000));

    let one_key_rows = est.estimate(&one_key_plan);
    let two_key_rows = est.estimate(&two_key_plan);

    // Grouping always reduces the 10000 input rows.
    assert!(one_key_rows < 10000.0);
    assert!(two_key_rows < 10000.0);
    // An aggregate yields at least one group.
    assert!(one_key_rows >= 1.0);
    assert!(two_key_rows >= 1.0);
}
1983
1984    // ============= Histogram Tests =============
1985
/// Ten buckets over the values 0..100 each hold roughly ten rows.
#[test]
fn test_histogram_build_uniform() {
    // One row per integer in [0, 100).
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    assert_eq!(hist.num_buckets(), 10);
    assert_eq!(hist.total_rows(), 100);

    // Allow one row of slack around the ideal depth of 10.
    for bucket in hist.buckets() {
        assert!((9..=11).contains(&bucket.frequency));
    }
}
2000
/// Bucket depths stay balanced even when the value distribution is heavily skewed.
#[test]
fn test_histogram_build_skewed() {
    // 80 small values (0..80) followed by 20 large ones (1000..1020).
    let small = (0..80_i32).map(f64::from);
    let large = (0..20_i32).map(|offset| 1000.0 + f64::from(offset));
    let samples: Vec<f64> = small.chain(large).collect();
    let hist = EquiDepthHistogram::build(&samples, 5);

    assert_eq!(hist.num_buckets(), 5);
    assert_eq!(hist.total_rows(), 100);

    // Ideal depth is 20 rows per bucket; tolerate +/- 2.
    for bucket in hist.buckets() {
        assert!((18..=22).contains(&bucket.frequency));
    }
}
2016
/// An unbounded range selects (almost exactly) every row.
#[test]
fn test_histogram_range_selectivity_full() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    // No lower and no upper bound: selectivity should be ~1.0.
    let fraction = hist.range_selectivity(None, None);
    assert!((fraction - 1.0).abs() < 0.01);
}
2026
/// A lower bound at the median of uniform data selects about half of the rows.
#[test]
fn test_histogram_range_selectivity_half() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    // Lower bound 50 with no upper bound: expect roughly 50%.
    let fraction = hist.range_selectivity(Some(50.0), None);
    assert!(fraction > 0.4);
    assert!(fraction < 0.6);
}
2036
/// An upper bound at the first quartile of uniform data selects about a quarter of the rows.
#[test]
fn test_histogram_range_selectivity_quarter() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    // Upper bound 25 with no lower bound: expect roughly 25%.
    let fraction = hist.range_selectivity(None, Some(25.0));
    assert!(fraction > 0.2);
    assert!(fraction < 0.3);
}
2046
/// Equality on one of 100 distinct, equally frequent values selects about 1% of rows.
#[test]
fn test_histogram_equality_selectivity() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let hist = EquiDepthHistogram::build(&samples, 10);

    // Accept anything between 0.5% and 2%.
    let fraction = hist.equality_selectivity(50.0);
    assert!(fraction > 0.005);
    assert!(fraction < 0.02);
}
2056
/// A histogram built from no values has no buckets and falls back to a fixed range selectivity.
#[test]
fn test_histogram_empty() {
    let hist = EquiDepthHistogram::build(&[], 10);

    assert_eq!(hist.num_buckets(), 0);
    assert_eq!(hist.total_rows(), 0);

    // With nothing to consult, a bounded range is expected to yield ~0.33.
    let fallback = hist.range_selectivity(Some(0.0), Some(100.0));
    assert!((fallback - 0.33).abs() < 0.01);
}
2068
/// `overlap_fraction` reports which share of the bucket `[10, 20)` a query range covers.
#[test]
fn test_histogram_bucket_overlap() {
    let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);

    // (query lower bound, query upper bound, expected overlap fraction)
    let cases = [
        (10.0, 20.0, 1.0), // full overlap
        (10.0, 15.0, 0.5), // lower half
        (15.0, 20.0, 0.5), // upper half
        (0.0, 5.0, 0.0),   // entirely below the bucket
        (25.0, 30.0, 0.0), // entirely above the bucket
    ];

    for (lower, upper, expected) in cases {
        let overlap = bucket.overlap_fraction(Some(lower), Some(upper));
        assert!((overlap - expected).abs() < 0.01);
    }
}
2088
/// `ColumnStats::from_values` derives distinct count, min/max and a histogram from raw values.
#[test]
fn test_column_stats_from_values() {
    // Eight samples, five distinct values: 10, 20, 30, 40, 50.
    let samples = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
    let derived = ColumnStats::from_values(samples, 4);

    assert_eq!(derived.distinct_count, 5);

    assert!(derived.min_value.is_some());
    let lowest = derived.min_value.unwrap();
    assert!((lowest - 10.0).abs() < 0.01);

    assert!(derived.max_value.is_some());
    let highest = derived.max_value.unwrap();
    assert!((highest - 50.0).abs() < 0.01);

    assert!(derived.histogram.is_some());
}
2101
/// `with_histogram` attaches the supplied histogram to the column statistics.
#[test]
fn test_column_stats_with_histogram_builder() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let ten_buckets = EquiDepthHistogram::build(&samples, 10);

    let column = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(ten_buckets);

    assert!(column.histogram.is_some());
    let attached = column.histogram.as_ref().unwrap();
    assert_eq!(attached.num_buckets(), 10);
}
2114
/// Estimates `age > 50` over 1000 `Person` rows whose `age` column carries a histogram
/// and expects a result between 300 and 600 rows.
#[test]
fn test_filter_with_histogram_stats() {
    // Histogram over ages 18..=79, one sample per year.
    let ages: Vec<f64> = (18..80_i32).map(f64::from).collect();
    let age_column = ColumnStats::new(62)
        .with_range(18.0, 79.0)
        .with_histogram(EquiDepthHistogram::build(&ages, 10));

    let mut est = CardinalityEstimator::new();
    est.add_table_stats(
        "Person",
        TableStats::new(1000).with_column("age", age_column),
    );

    // Predicate: n.age > 50
    let age_over_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: age_over_50,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let estimated_rows = est.estimate(&plan);

    // (79 - 50) / (79 - 18) = 29/61, i.e. about 47.5% of 1000 rows (~475).
    assert!(estimated_rows > 300.0);
    assert!(estimated_rows < 600.0);
}
2156
/// Estimates `value = 50` over 1000 `Data` rows with a 100-distinct-value histogram
/// and expects a result in `[1, 50)`.
#[test]
fn test_filter_equality_with_histogram() {
    let samples: Vec<f64> = (0..100_i32).map(f64::from).collect();
    let value_column = ColumnStats::new(100)
        .with_range(0.0, 99.0)
        .with_histogram(EquiDepthHistogram::build(&samples, 10));

    let mut est = CardinalityEstimator::new();
    est.add_table_stats(
        "Data",
        TableStats::new(1000).with_column("value", value_column),
    );

    // Predicate: d.value = 50
    let value_is_50 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("d"),
            property: String::from("value"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
    };
    let data_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("d"),
        label: Some(String::from("Data")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: value_is_50,
        input: Box::new(data_scan),
        pushdown_hint: None,
    });

    let estimated_rows = est.estimate(&plan);

    // ~1% selectivity on 1000 rows gives ~10; accept anything in [1, 50).
    assert!(estimated_rows >= 1.0 && estimated_rows < 50.0);
}
2194
/// A non-empty histogram exposes the smallest input as its minimum and reports some maximum.
#[test]
fn test_histogram_min_max() {
    let samples = vec![5.0, 10.0, 15.0, 20.0, 25.0];
    let hist = EquiDepthHistogram::build(&samples, 2);

    assert_eq!(hist.min_value(), Some(5.0));
    // Only presence is checked for the maximum, not its exact value.
    let upper = hist.max_value();
    assert!(upper.is_some());
}
2204
2205    // ==================== SelectivityConfig Tests ====================
2206
/// `SelectivityConfig::new` yields the documented default for every selectivity knob.
#[test]
fn test_selectivity_config_defaults() {
    let config = SelectivityConfig::new();

    // (actual field value, expected default)
    let expectations = [
        (config.default, 0.1),
        (config.equality, 0.01),
        (config.inequality, 0.99),
        (config.range, 0.33),
        (config.string_ops, 0.1),
        (config.membership, 0.1),
        (config.is_null, 0.05),
        (config.is_not_null, 0.95),
        (config.distinct_fraction, 0.5),
    ];

    for (actual, expected) in expectations {
        assert!((actual - expected).abs() < f64::EPSILON);
    }
}
2220
/// An estimator built from a custom config reports the overridden selectivities.
#[test]
fn test_custom_selectivity_config() {
    // Start from the defaults and override only equality and range.
    let mut custom = SelectivityConfig::new();
    custom.equality = 0.05;
    custom.range = 0.25;

    let est = CardinalityEstimator::with_selectivity_config(custom);

    let equality = est.selectivity_config().equality;
    assert!((equality - 0.05).abs() < f64::EPSILON);
    let range = est.selectivity_config().range;
    assert!((range - 0.25).abs() < f64::EPSILON);
}
2232
/// Raising the equality selectivity raises the estimate of an equality filter accordingly.
#[test]
fn test_custom_selectivity_affects_estimation() {
    // Shared plan: Filter(n.name = "Alix") over NodeScan(n:Person).
    let name_is_alix = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("name"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: name_is_alix,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    // Baseline: default equality selectivity 0.01 -> 1000 * 0.01 = 10.
    let mut baseline = CardinalityEstimator::new();
    baseline.add_table_stats("Person", TableStats::new(1000));
    let baseline_rows = baseline.estimate(&plan);

    // Tuned: equality selectivity 0.2 -> 1000 * 0.2 = 200.
    let tuned_config = SelectivityConfig {
        equality: 0.2,
        ..SelectivityConfig::new()
    };
    let mut tuned = CardinalityEstimator::with_selectivity_config(tuned_config);
    tuned.add_table_stats("Person", TableStats::new(1000));
    let tuned_rows = tuned.estimate(&plan);

    assert!(tuned_rows > baseline_rows);
    assert!((tuned_rows - 200.0).abs() < 1.0);
}
2271
/// Without column statistics, a range filter uses the configured range selectivity.
#[test]
fn test_custom_range_selectivity() {
    let half_range = SelectivityConfig {
        range: 0.5,
        ..SelectivityConfig::new()
    };
    let mut est = CardinalityEstimator::with_selectivity_config(half_range);
    est.add_table_stats("Person", TableStats::new(1000));

    // Predicate: n.age > 30
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    // 1000 rows * 0.5 = 500.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 500.0).abs() < 1.0);
}
2302
/// `Distinct` scales its input by the configured `distinct_fraction`.
#[test]
fn test_custom_distinct_fraction() {
    let mostly_unique = SelectivityConfig {
        distinct_fraction: 0.8,
        ..SelectivityConfig::new()
    };
    let mut est = CardinalityEstimator::with_selectivity_config(mostly_unique);
    est.add_table_stats("Person", TableStats::new(1000));

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Person")),
        input: None,
    });
    let plan = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(person_scan),
        columns: None,
    });

    // 1000 rows * 0.8 = 800.
    let estimated_rows = est.estimate(&plan);
    assert!((estimated_rows - 800.0).abs() < 1.0);
}
2325
2326    // ==================== EstimationLog Tests ====================
2327
/// Small estimation errors are recorded but do not request a re-plan.
#[test]
fn test_estimation_log_basic() {
    let mut log = EstimationLog::new(10.0);

    // (operator, estimated, actual): errors of 1.2x and 0.9x, far inside the 10x threshold.
    for (operator, estimated, actual) in [
        ("NodeScan(Person)", 1000.0, 1200.0),
        ("Filter(age > 30)", 100.0, 90.0),
    ] {
        log.record(operator, estimated, actual);
    }

    assert_eq!(log.entries().len(), 2);
    let replan = log.should_replan();
    assert!(!replan);
}
2337
/// A 50x underestimate exceeds a 10x threshold and requests a re-plan.
#[test]
fn test_estimation_log_triggers_replan() {
    let threshold = 10.0;
    let mut log = EstimationLog::new(threshold);

    // Estimated 100 rows, observed 5000.
    log.record("NodeScan(Person)", 100.0, 5000.0);

    let replan = log.should_replan();
    assert!(replan);
}
2345
/// Overestimates count too: a 10x overestimate trips a 5x threshold.
#[test]
fn test_estimation_log_overestimate_triggers_replan() {
    let threshold = 5.0;
    let mut log = EstimationLog::new(threshold);

    // Estimated 1000 rows, observed 100: ratio 0.1, below 1/5 = 0.2.
    log.record("Filter", 1000.0, 100.0);

    let replan = log.should_replan();
    assert!(replan);
}
2353
/// `error_ratio` is actual/estimated, with a zero estimate and zero actual treated as a perfect 1.0.
#[test]
fn test_estimation_entry_error_ratio() {
    // (estimated, actual, expected error ratio)
    let cases = [
        (100.0, 200.0, 2.0), // 2x underestimate
        (100.0, 100.0, 1.0), // perfect estimate
        (0.0, 0.0, 1.0),     // both zero
    ];

    for (estimated, actual, expected_ratio) in cases {
        let entry = EstimationEntry {
            operator: "test".into(),
            estimated,
            actual,
        };
        assert!((entry.error_ratio() - expected_ratio).abs() < f64::EPSILON);
    }
}
2377
/// `max_error_ratio` returns the largest normalized error across all recorded entries.
#[test]
fn test_estimation_log_max_error_ratio() {
    let mut log = EstimationLog::new(10.0);

    // (operator, estimated, actual)
    for (operator, estimated, actual) in [
        ("A", 100.0, 300.0), // 3x
        ("B", 100.0, 50.0),  // 2x once normalized (1 / 0.5)
        ("C", 100.0, 100.0), // 1x
    ] {
        log.record(operator, estimated, actual);
    }

    let worst = log.max_error_ratio();
    assert!((worst - 3.0).abs() < f64::EPSILON);
}
2387
/// `clear` drops every recorded entry and leaves the log in a no-replan state.
#[test]
fn test_estimation_log_clear() {
    let mut log = EstimationLog::new(10.0);

    log.record("A", 100.0, 100.0);
    let recorded = log.entries().len();
    assert_eq!(recorded, 1);

    log.clear();

    assert!(log.entries().is_empty());
    let replan = log.should_replan();
    assert!(!replan);
}
2398
/// A log created through the estimator starts empty and does not request a re-plan.
#[test]
fn test_create_estimation_log() {
    let fresh_log = CardinalityEstimator::create_estimation_log();

    assert!(fresh_log.entries().is_empty());
    let replan = fresh_log.should_replan();
    assert!(!replan);
}
2405}