Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19    LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20    TripleComponent, TripleScanOp, UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram.
///
/// A bucket covers the half-open interval `[lower_bound, upper_bound)` and
/// records how many rows, and how many distinct values, fell inside it.
/// In an equi-depth histogram every bucket holds roughly the same number
/// of rows, so bucket widths follow the density of the data.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper edge of the bucket. Histograms produced by
    /// `EquiDepthHistogram::build` place the last bucket's edge past the
    /// largest value so that value is still covered.
    pub upper_bound: f64,
    /// Number of rows that fell in this bucket.
    pub frequency: u64,
    /// Number of distinct values among those rows.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds, row count and distinct-value count.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            frequency,
            distinct_count,
            lower_bound,
            upper_bound,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns whether `value` lies in the half-open range `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Returns the fraction (0.0 to 1.0) of this bucket's width covered by the
    /// range `[lower, upper]`, where `None` leaves that side unbounded.
    ///
    /// The range is clipped to the bucket first. A zero-width (or inverted)
    /// bucket has no interior to measure. It therefore counts as fully covered
    /// (1.0) when the range spans both of its bounds, and as uncovered (0.0)
    /// otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let lo = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let hi = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));
        let span = self.width();

        if span <= 0.0 {
            let spans_bucket = lo <= self.lower_bound && hi >= self.upper_bound;
            return if spans_bucket { 1.0 } else { 0.0 };
        }

        let covered = (hi - lo).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```no_run
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower
213                && bucket_upper <= l
214            {
215                continue;
216            }
217            if let Some(u) = upper
218                && bucket_lower >= u
219            {
220                continue;
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
270/// Counts distinct values in a sorted slice.
271/// Counts the number of top-level AND conjuncts in an expression.
272///
273/// For `(A AND B) AND (C AND D)` returns 4; for a single non-AND expression
274/// returns 1. Used to estimate selectivity of multi-predicate join conditions.
275fn count_and_conjuncts(expr: &LogicalExpression) -> usize {
276    match expr {
277        LogicalExpression::Binary {
278            op: BinaryOp::And,
279            left,
280            right,
281        } => count_and_conjuncts(left) + count_and_conjuncts(right),
282        _ => 1,
283    }
284}
285
/// Counts the distinct values in a slice sorted in ascending order.
///
/// Because the input is sorted, equal values are adjacent. The count is
/// therefore one plus the number of positions where a value differs from
/// its predecessor. Values are compared exactly, so any two different `f64`s
/// are distinct however close they are. Returns 0 for an empty slice.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }

    // Exact `!=` rather than an absolute-epsilon test. `f64::EPSILON` is the
    // gap between 1.0 and the next float, so `|a - b| > EPSILON` merges
    // genuinely different values near or below 1.0 (e.g. 1e-20 and 2e-20).
    let changes = sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();
    1 + changes as u64
}
303
304/// Statistics for a table/label.
305#[derive(Debug, Clone)]
306pub struct TableStats {
307    /// Total number of rows.
308    pub row_count: u64,
309    /// Column statistics.
310    pub columns: HashMap<String, ColumnStats>,
311}
312
313impl TableStats {
314    /// Creates new table statistics.
315    #[must_use]
316    pub fn new(row_count: u64) -> Self {
317        Self {
318            row_count,
319            columns: HashMap::new(),
320        }
321    }
322
323    /// Adds column statistics.
324    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
325        self.columns.insert(name.to_string(), stats);
326        self
327    }
328}
329
330/// Statistics for a column.
331#[derive(Debug, Clone)]
332pub struct ColumnStats {
333    /// Number of distinct values.
334    pub distinct_count: u64,
335    /// Number of null values.
336    pub null_count: u64,
337    /// Minimum value (if orderable).
338    pub min_value: Option<f64>,
339    /// Maximum value (if orderable).
340    pub max_value: Option<f64>,
341    /// Equi-depth histogram for accurate selectivity estimation.
342    pub histogram: Option<EquiDepthHistogram>,
343}
344
345impl ColumnStats {
346    /// Creates new column statistics.
347    #[must_use]
348    pub fn new(distinct_count: u64) -> Self {
349        Self {
350            distinct_count,
351            null_count: 0,
352            min_value: None,
353            max_value: None,
354            histogram: None,
355        }
356    }
357
358    /// Sets the null count.
359    #[must_use]
360    pub fn with_nulls(mut self, null_count: u64) -> Self {
361        self.null_count = null_count;
362        self
363    }
364
365    /// Sets the min/max range.
366    #[must_use]
367    pub fn with_range(mut self, min: f64, max: f64) -> Self {
368        self.min_value = Some(min);
369        self.max_value = Some(max);
370        self
371    }
372
373    /// Sets the equi-depth histogram for this column.
374    #[must_use]
375    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
376        self.histogram = Some(histogram);
377        self
378    }
379
380    /// Builds column statistics with histogram from raw values.
381    ///
382    /// This is a convenience method that computes all statistics from the data.
383    ///
384    /// # Arguments
385    /// * `values` - The column values (will be sorted internally)
386    /// * `num_buckets` - Number of histogram buckets to create
387    #[must_use]
388    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
389        if values.is_empty() {
390            return Self::new(0);
391        }
392
393        // Sort values for histogram building
394        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
395
396        let min = values.first().copied();
397        let max = values.last().copied();
398        let distinct_count = count_distinct(&values);
399        let histogram = EquiDepthHistogram::build(&values, num_buckets);
400
401        Self {
402            distinct_count,
403            null_count: 0,
404            min_value: min,
405            max_value: max,
406            histogram: Some(histogram),
407        }
408    }
409}
410
/// Fallback selectivities used when estimating cardinalities.
///
/// Most fields give the fraction of rows (0.0 to 1.0) assumed to satisfy a
/// predicate of a given kind when no histogram or column statistics apply.
/// Tuning them can improve plans for workloads whose data skew is known in
/// advance.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Any predicate without a more specific rule (default: 0.1).
    pub default: f64,
    /// Equality predicates when no statistics exist (default: 0.01).
    pub equality: f64,
    /// Inequality (`<>`) predicates (default: 0.99).
    pub inequality: f64,
    /// Range predicates when no statistics exist (default: 0.33).
    pub range: f64,
    /// STARTS WITH, ENDS WITH, CONTAINS and LIKE (default: 0.1).
    pub string_ops: f64,
    /// IN-list membership (default: 0.1).
    pub membership: f64,
    /// IS NULL (default: 0.05).
    pub is_null: f64,
    /// IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive DISTINCT (default: 0.5).
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns the standard defaults; identical to `SelectivityConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    fn default() -> Self {
        Self {
            distinct_fraction: 0.5,
            is_not_null: 0.95,
            is_null: 0.05,
            membership: 0.1,
            string_ops: 0.1,
            range: 0.33,
            inequality: 0.99,
            equality: 0.01,
            default: 0.1,
        }
    }
}
461
/// One estimated-versus-observed cardinality sample for a plan operator.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Display label of the operator, e.g. "NodeScan(Person)".
    pub operator: String,
    /// Row count the optimizer predicted.
    pub estimated: f64,
    /// Row count observed during execution.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// A ratio near 1.0 means the estimate was accurate. Above 1.0 the
    /// operator was underestimated; below 1.0 it was overestimated. When the
    /// estimate is (within `f64::EPSILON` of) zero, the ratio is 1.0 if the
    /// actual count is also zero and infinity otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        let near_zero = |x: f64| x.abs() < f64::EPSILON;
        match (near_zero(self.estimated), near_zero(self.actual)) {
            (false, _) => self.actual / self.estimated,
            (true, true) => 1.0,
            (true, false) => f64::INFINITY,
        }
    }
}
492
493/// Collects estimate vs actual cardinality data for query plan analysis.
494///
495/// After executing a query, call [`record()`](Self::record) for each
496/// operator with its estimated and actual cardinalities. Then use
497/// [`should_replan()`](Self::should_replan) to decide whether the plan
498/// should be re-optimized.
499#[derive(Debug, Clone, Default)]
500pub struct EstimationLog {
501    /// Recorded entries.
502    entries: Vec<EstimationEntry>,
503    /// Error ratio threshold that triggers re-planning (default: 10.0).
504    ///
505    /// If any operator's error ratio exceeds this, `should_replan()` returns true.
506    replan_threshold: f64,
507}
508
509impl EstimationLog {
510    /// Creates a new estimation log with the given re-planning threshold.
511    #[must_use]
512    pub fn new(replan_threshold: f64) -> Self {
513        Self {
514            entries: Vec::new(),
515            replan_threshold,
516        }
517    }
518
519    /// Records an estimate-vs-actual observation.
520    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
521        self.entries.push(EstimationEntry {
522            operator: operator.into(),
523            estimated,
524            actual,
525        });
526    }
527
528    /// Returns all recorded entries.
529    #[must_use]
530    pub fn entries(&self) -> &[EstimationEntry] {
531        &self.entries
532    }
533
534    /// Returns whether any operator's estimation error exceeds the threshold,
535    /// indicating the plan should be re-optimized.
536    #[must_use]
537    pub fn should_replan(&self) -> bool {
538        self.entries.iter().any(|e| {
539            let ratio = e.error_ratio();
540            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
541        })
542    }
543
544    /// Returns the maximum error ratio across all entries.
545    #[must_use]
546    pub fn max_error_ratio(&self) -> f64 {
547        self.entries
548            .iter()
549            .map(|e| {
550                let r = e.error_ratio();
551                // Normalize so both over- and under-estimation are > 1.0
552                if r < 1.0 { 1.0 / r } else { r }
553            })
554            .fold(1.0_f64, f64::max)
555    }
556
557    /// Clears all entries.
558    pub fn clear(&mut self) {
559        self.entries.clear();
560    }
561}
562
563/// Cardinality estimator.
564pub struct CardinalityEstimator {
565    /// Statistics for each label/table.
566    table_stats: HashMap<String, TableStats>,
567    /// Default row count for unknown tables.
568    default_row_count: u64,
569    /// Default selectivity for unknown predicates.
570    default_selectivity: f64,
571    /// Average edge fanout (outgoing edges per node).
572    avg_fanout: f64,
573    /// Configurable selectivity defaults.
574    selectivity_config: SelectivityConfig,
575    /// RDF statistics for triple pattern cardinality estimation.
576    rdf_statistics: Option<grafeo_core::statistics::RdfStatistics>,
577}
578
579impl CardinalityEstimator {
580    /// Creates a new cardinality estimator with default settings.
581    #[must_use]
582    pub fn new() -> Self {
583        let config = SelectivityConfig::new();
584        Self {
585            table_stats: HashMap::new(),
586            default_row_count: 1000,
587            default_selectivity: config.default,
588            avg_fanout: 10.0,
589            selectivity_config: config,
590            rdf_statistics: None,
591        }
592    }
593
594    /// Creates a new cardinality estimator with custom selectivity configuration.
595    #[must_use]
596    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
597        Self {
598            table_stats: HashMap::new(),
599            default_row_count: 1000,
600            default_selectivity: config.default,
601            avg_fanout: 10.0,
602            selectivity_config: config,
603            rdf_statistics: None,
604        }
605    }
606
607    /// Returns the current selectivity configuration.
608    #[must_use]
609    pub fn selectivity_config(&self) -> &SelectivityConfig {
610        &self.selectivity_config
611    }
612
613    /// Creates an estimation log with the default re-planning threshold (10x).
614    #[must_use]
615    pub fn create_estimation_log() -> EstimationLog {
616        EstimationLog::new(10.0)
617    }
618
619    /// Creates a cardinality estimator pre-populated from store statistics.
620    ///
621    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
622    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
623    /// missing statistics.
624    #[must_use]
625    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
626        let mut estimator = Self::new();
627
628        // Use total node count as default for unlabeled scans
629        if stats.total_nodes > 0 {
630            estimator.default_row_count = stats.total_nodes;
631        }
632
633        // Convert label statistics to optimizer table stats
634        for (label, label_stats) in &stats.labels {
635            let mut table_stats = TableStats::new(label_stats.node_count);
636
637            // Map property statistics (distinct count for selectivity estimation)
638            for (prop, col_stats) in &label_stats.properties {
639                let optimizer_col =
640                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
641                table_stats = table_stats.with_column(prop, optimizer_col);
642            }
643
644            estimator.add_table_stats(label, table_stats);
645        }
646
647        // Compute average fanout from edge type statistics
648        if !stats.edge_types.is_empty() {
649            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
650            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
651        } else if stats.total_nodes > 0 {
652            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
653        }
654
655        // Clamp fanout to a reasonable minimum
656        if estimator.avg_fanout < 1.0 {
657            estimator.avg_fanout = 1.0;
658        }
659
660        estimator
661    }
662
663    /// Creates a cardinality estimator from RDF statistics.
664    ///
665    /// Uses triple pattern cardinality estimates for `TripleScan` operators
666    /// and join selectivity from per-predicate statistics.
667    #[must_use]
668    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
669        let mut estimator = Self::new();
670        if rdf_stats.total_triples > 0 {
671            estimator.default_row_count = rdf_stats.total_triples;
672        }
673        estimator.rdf_statistics = Some(rdf_stats);
674        estimator
675    }
676
677    /// Adds statistics for a table/label.
678    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
679        self.table_stats.insert(name.to_string(), stats);
680    }
681
682    /// Sets the average edge fanout.
683    pub fn set_avg_fanout(&mut self, fanout: f64) {
684        self.avg_fanout = fanout;
685    }
686
687    /// Estimates the cardinality of a logical operator.
688    #[must_use]
689    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
690        match op {
691            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
692            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
693            LogicalOperator::Project(project) => self.estimate_project(project),
694            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
695            LogicalOperator::Join(join) => self.estimate_join(join),
696            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
697            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
698            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
699            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
700            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
701            LogicalOperator::Return(ret) => self.estimate(&ret.input),
702            LogicalOperator::Empty => 0.0,
703            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
704            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
705            LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
706            LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
707            LogicalOperator::TripleScan(scan) => self.estimate_triple_scan(scan),
708            _ => self.default_row_count as f64,
709        }
710    }
711
712    /// Estimates node scan cardinality.
713    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
714        if let Some(label) = &scan.label
715            && let Some(stats) = self.table_stats.get(label)
716        {
717            return stats.row_count as f64;
718        }
719        // No label filter - scan all nodes
720        self.default_row_count as f64
721    }
722
723    /// Estimates triple scan cardinality using RDF statistics.
724    ///
725    /// If RDF statistics are available, uses the pattern binding (which positions
726    /// are bound vs variable) to produce accurate estimates. Otherwise falls back
727    /// to the default row count.
728    fn estimate_triple_scan(&self, scan: &TripleScanOp) -> f64 {
729        // If there's an input, the triple scan is chained: multiply input cardinality
730        // by the per-row expansion factor.
731        let base = if let Some(ref input) = scan.input {
732            self.estimate(input)
733        } else {
734            1.0
735        };
736
737        let Some(rdf_stats) = &self.rdf_statistics else {
738            return if scan.input.is_some() {
739                base * self.default_row_count as f64
740            } else {
741                self.default_row_count as f64
742            };
743        };
744
745        let subject_bound = matches!(
746            scan.subject,
747            TripleComponent::Iri(_)
748                | TripleComponent::Literal(_)
749                | TripleComponent::LangLiteral { .. }
750        );
751        let object_bound = matches!(
752            scan.object,
753            TripleComponent::Iri(_)
754                | TripleComponent::Literal(_)
755                | TripleComponent::LangLiteral { .. }
756        );
757        let predicate_iri = match &scan.predicate {
758            TripleComponent::Iri(iri) => Some(iri.as_str()),
759            _ => None,
760        };
761
762        let pattern_card = rdf_stats.estimate_triple_pattern_cardinality(
763            subject_bound,
764            predicate_iri,
765            object_bound,
766        );
767
768        if scan.input.is_some() {
769            // Chained scan: each input row expands by the pattern's selectivity
770            let selectivity = if rdf_stats.total_triples > 0 {
771                pattern_card / rdf_stats.total_triples as f64
772            } else {
773                1.0
774            };
775            (base * pattern_card * selectivity).max(1.0)
776        } else {
777            pattern_card.max(1.0)
778        }
779    }
780
781    /// Estimates filter cardinality.
782    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
783        let input_cardinality = self.estimate(&filter.input);
784        let selectivity = self.estimate_selectivity(&filter.predicate);
785        (input_cardinality * selectivity).max(1.0)
786    }
787
788    /// Estimates projection cardinality (same as input).
789    fn estimate_project(&self, project: &ProjectOp) -> f64 {
790        self.estimate(&project.input)
791    }
792
793    /// Estimates expand cardinality.
794    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
795        let input_cardinality = self.estimate(&expand.input);
796
797        // Apply fanout based on edge type
798        let fanout = if !expand.edge_types.is_empty() {
799            // Specific edge type(s) typically have lower fanout
800            self.avg_fanout * 0.5
801        } else {
802            self.avg_fanout
803        };
804
805        // Handle variable-length paths
806        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
807            let min = expand.min_hops as f64;
808            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
809            // Geometric series approximation
810            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
811        } else {
812            fanout
813        };
814
815        (input_cardinality * path_multiplier).max(1.0)
816    }
817
818    /// Estimates join cardinality.
819    fn estimate_join(&self, join: &JoinOp) -> f64 {
820        let left_card = self.estimate(&join.left);
821        let right_card = self.estimate(&join.right);
822
823        match join.join_type {
824            JoinType::Cross => left_card * right_card,
825            JoinType::Inner => {
826                // Assume join selectivity based on conditions
827                let selectivity = if join.conditions.is_empty() {
828                    1.0 // Cross join
829                } else {
830                    // Estimate based on number of conditions
831                    // reason: join condition count is always small (< 100)
832                    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
833                    let exp = join.conditions.len() as i32;
834                    0.1_f64.powi(exp)
835                };
836                (left_card * right_card * selectivity).max(1.0)
837            }
838            JoinType::Left => {
839                // Left join returns at least all left rows
840                let inner_card = self.estimate_join(&JoinOp {
841                    left: join.left.clone(),
842                    right: join.right.clone(),
843                    join_type: JoinType::Inner,
844                    conditions: join.conditions.clone(),
845                });
846                inner_card.max(left_card)
847            }
848            JoinType::Right => {
849                // Right join returns at least all right rows
850                let inner_card = self.estimate_join(&JoinOp {
851                    left: join.left.clone(),
852                    right: join.right.clone(),
853                    join_type: JoinType::Inner,
854                    conditions: join.conditions.clone(),
855                });
856                inner_card.max(right_card)
857            }
858            JoinType::Full => {
859                // Full join returns at least max(left, right)
860                let inner_card = self.estimate_join(&JoinOp {
861                    left: join.left.clone(),
862                    right: join.right.clone(),
863                    join_type: JoinType::Inner,
864                    conditions: join.conditions.clone(),
865                });
866                inner_card.max(left_card.max(right_card))
867            }
868            JoinType::Semi => {
869                // Semi join returns at most left cardinality
870                (left_card * self.default_selectivity).max(1.0)
871            }
872            JoinType::Anti => {
873                // Anti join returns at most left cardinality
874                (left_card * (1.0 - self.default_selectivity)).max(1.0)
875            }
876        }
877    }
878
879    /// Estimates left join cardinality (OPTIONAL MATCH).
880    ///
881    /// A left outer join preserves all left rows, so the output is at least
882    /// `left_cardinality`. When the right side matches, the output may be
883    /// larger (one left row can match multiple right rows).
884    ///
885    /// When the join carries a cross-side condition (null-safe predicates),
886    /// each AND-conjunct reduces the selectivity estimate.
887    fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
888        let left_card = self.estimate(&lj.left);
889        let right_card = self.estimate(&lj.right);
890
891        // Adjust selectivity based on the number of AND conjuncts in the
892        // cross-side condition: each equality reduces match probability.
893        let condition_selectivity = if let Some(cond) = &lj.condition {
894            let n = count_and_conjuncts(cond);
895            // reason: conjunct count is always small (< 100)
896            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
897            let exp = n as i32;
898            self.default_selectivity.powi(exp)
899        } else {
900            self.default_selectivity
901        };
902
903        // Estimate as inner join cardinality, but guaranteed at least left_card
904        let inner_estimate = left_card * right_card * condition_selectivity;
905        inner_estimate.max(left_card).max(1.0)
906    }
907
908    /// Estimates aggregation cardinality.
909    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
910        let input_cardinality = self.estimate(&agg.input);
911
912        if agg.group_by.is_empty() {
913            // Global aggregation - single row
914            1.0
915        } else {
916            // Group by - estimate distinct groups
917            // Assume each group key reduces cardinality by 10
918            // reason: group-by key count is always small (< 100)
919            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
920            let exp = agg.group_by.len() as i32;
921            let group_reduction = 10.0_f64.powi(exp);
922            (input_cardinality / group_reduction).max(1.0)
923        }
924    }
925
926    /// Estimates sort cardinality (same as input).
927    fn estimate_sort(&self, sort: &SortOp) -> f64 {
928        self.estimate(&sort.input)
929    }
930
931    /// Estimates distinct cardinality.
932    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
933        let input_cardinality = self.estimate(&distinct.input);
934        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
935    }
936
937    /// Estimates limit cardinality.
938    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
939        let input_cardinality = self.estimate(&limit.input);
940        limit.count.estimate().min(input_cardinality)
941    }
942
943    /// Estimates skip cardinality.
944    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
945        let input_cardinality = self.estimate(&skip.input);
946        (input_cardinality - skip.count.estimate()).max(0.0)
947    }
948
949    /// Estimates vector scan cardinality.
950    ///
951    /// Vector scan returns at most k results (the k nearest neighbors).
952    /// With similarity/distance filters, it may return fewer.
953    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
954        let base_k = scan.k as f64;
955
956        // Apply filter selectivity if thresholds are specified
957        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
958            // Assume 70% of results pass threshold filters
959            0.7
960        } else {
961            1.0
962        };
963
964        (base_k * selectivity).max(1.0)
965    }
966
967    /// Estimates vector join cardinality.
968    ///
969    /// Vector join produces up to k results per input row.
970    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
971        let input_cardinality = self.estimate(&join.input);
972        let k = join.k as f64;
973
974        // Apply filter selectivity if thresholds are specified
975        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
976            0.7
977        } else {
978            1.0
979        };
980
981        (input_cardinality * k * selectivity).max(1.0)
982    }
983
984    /// Estimates multi-way join cardinality using the AGM bound heuristic.
985    ///
986    /// For a cyclic join of N relations, the AGM (Atserias-Grohe-Marx) bound
987    /// gives min(cardinality)^(N/2) as a worst-case output size estimate.
988    fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
989        if mwj.inputs.is_empty() {
990            return 0.0;
991        }
992        let cardinalities: Vec<f64> = mwj
993            .inputs
994            .iter()
995            .map(|input| self.estimate(input))
996            .collect();
997        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
998        let n = cardinalities.len() as f64;
999        // AGM bound: min(cardinality)^(n/2)
1000        (min_card.powf(n / 2.0)).max(1.0)
1001    }
1002
1003    /// Estimates the selectivity of a predicate (0.0 to 1.0).
1004    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
1005        match expr {
1006            LogicalExpression::Binary { left, op, right } => {
1007                self.estimate_binary_selectivity(left, *op, right)
1008            }
1009            LogicalExpression::Unary { op, operand } => {
1010                self.estimate_unary_selectivity(*op, operand)
1011            }
1012            LogicalExpression::Literal(value) => {
1013                // Boolean literal
1014                if let grafeo_common::types::Value::Bool(b) = value {
1015                    if *b { 1.0 } else { 0.0 }
1016                } else {
1017                    self.default_selectivity
1018                }
1019            }
1020            _ => self.default_selectivity,
1021        }
1022    }
1023
1024    /// Estimates binary expression selectivity.
1025    fn estimate_binary_selectivity(
1026        &self,
1027        left: &LogicalExpression,
1028        op: BinaryOp,
1029        right: &LogicalExpression,
1030    ) -> f64 {
1031        match op {
1032            // Equality - try histogram-based estimation
1033            BinaryOp::Eq => {
1034                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
1035                    return selectivity;
1036                }
1037                self.selectivity_config.equality
1038            }
1039            // Inequality is very unselective
1040            BinaryOp::Ne => self.selectivity_config.inequality,
1041            // Range predicates - use histogram if available
1042            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
1043                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
1044                    return selectivity;
1045                }
1046                self.selectivity_config.range
1047            }
1048            // Logical operators - recursively estimate sub-expressions
1049            BinaryOp::And => {
1050                let left_sel = self.estimate_selectivity(left);
1051                let right_sel = self.estimate_selectivity(right);
1052                // AND reduces selectivity (multiply assuming independence)
1053                left_sel * right_sel
1054            }
1055            BinaryOp::Or => {
1056                let left_sel = self.estimate_selectivity(left);
1057                let right_sel = self.estimate_selectivity(right);
1058                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
1059                // Assuming independence: P(A ∩ B) = P(A) * P(B)
1060                (left_sel + right_sel - left_sel * right_sel).min(1.0)
1061            }
1062            // String operations
1063            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
1064                self.selectivity_config.string_ops
1065            }
1066            // Collection membership
1067            BinaryOp::In => self.selectivity_config.membership,
1068            // Other operations
1069            _ => self.default_selectivity,
1070        }
1071    }
1072
1073    /// Tries to estimate equality selectivity using histograms.
1074    fn try_equality_selectivity(
1075        &self,
1076        left: &LogicalExpression,
1077        right: &LogicalExpression,
1078    ) -> Option<f64> {
1079        // Extract property access and literal value
1080        let (label, column, value) = self.extract_column_and_value(left, right)?;
1081
1082        // Get column stats with histogram
1083        let stats = self.get_column_stats(&label, &column)?;
1084
1085        // Try histogram-based estimation
1086        if let Some(ref histogram) = stats.histogram {
1087            return Some(histogram.equality_selectivity(value));
1088        }
1089
1090        // Fall back to distinct count estimation
1091        if stats.distinct_count > 0 {
1092            return Some(1.0 / stats.distinct_count as f64);
1093        }
1094
1095        None
1096    }
1097
1098    /// Tries to estimate range selectivity using histograms.
1099    fn try_range_selectivity(
1100        &self,
1101        left: &LogicalExpression,
1102        op: BinaryOp,
1103        right: &LogicalExpression,
1104    ) -> Option<f64> {
1105        // Extract property access and literal value
1106        let (label, column, value) = self.extract_column_and_value(left, right)?;
1107
1108        // Get column stats
1109        let stats = self.get_column_stats(&label, &column)?;
1110
1111        // Determine the range based on operator
1112        let (lower, upper) = match op {
1113            BinaryOp::Lt => (None, Some(value)),
1114            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1115            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1116            BinaryOp::Ge => (Some(value), None),
1117            _ => return None,
1118        };
1119
1120        // Try histogram-based estimation first
1121        if let Some(ref histogram) = stats.histogram {
1122            return Some(histogram.range_selectivity(lower, upper));
1123        }
1124
1125        // Fall back to min/max range estimation
1126        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1127            let range = max - min;
1128            if range <= 0.0 {
1129                return Some(1.0);
1130            }
1131
1132            let effective_lower = lower.unwrap_or(min).max(min);
1133            let effective_upper = upper.unwrap_or(max).min(max);
1134            let overlap = (effective_upper - effective_lower).max(0.0);
1135            return Some((overlap / range).clamp(0.0, 1.0));
1136        }
1137
1138        None
1139    }
1140
1141    /// Extracts column information and literal value from a comparison.
1142    ///
1143    /// Returns (label, column_name, numeric_value) if the expression is
1144    /// a comparison between a property access and a numeric literal.
1145    fn extract_column_and_value(
1146        &self,
1147        left: &LogicalExpression,
1148        right: &LogicalExpression,
1149    ) -> Option<(String, String, f64)> {
1150        // Try left as property, right as literal
1151        if let Some(result) = self.try_extract_property_literal(left, right) {
1152            return Some(result);
1153        }
1154
1155        // Try right as property, left as literal
1156        self.try_extract_property_literal(right, left)
1157    }
1158
1159    /// Tries to extract property and literal from a specific ordering.
1160    fn try_extract_property_literal(
1161        &self,
1162        property_expr: &LogicalExpression,
1163        literal_expr: &LogicalExpression,
1164    ) -> Option<(String, String, f64)> {
1165        // Extract property access
1166        let (variable, property) = match property_expr {
1167            LogicalExpression::Property { variable, property } => {
1168                (variable.clone(), property.clone())
1169            }
1170            _ => return None,
1171        };
1172
1173        // Extract numeric literal
1174        let value = match literal_expr {
1175            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1176            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1177            _ => return None,
1178        };
1179
1180        // Try to find a label for this variable from table stats
1181        // Use the variable name as a heuristic label lookup
1182        // In practice, the optimizer would track which labels variables are bound to
1183        for label in self.table_stats.keys() {
1184            if let Some(stats) = self.table_stats.get(label)
1185                && stats.columns.contains_key(&property)
1186            {
1187                return Some((label.clone(), property, value));
1188            }
1189        }
1190
1191        // If no stats found but we have the property, return with variable as label
1192        Some((variable, property, value))
1193    }
1194
1195    /// Estimates unary expression selectivity.
1196    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1197        match op {
1198            UnaryOp::Not => 1.0 - self.default_selectivity,
1199            UnaryOp::IsNull => self.selectivity_config.is_null,
1200            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1201            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1202        }
1203    }
1204
1205    /// Gets statistics for a column.
1206    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1207        self.table_stats.get(label)?.columns.get(column)
1208    }
1209}
1210
1211impl Default for CardinalityEstimator {
1212    fn default() -> Self {
1213        Self::new()
1214    }
1215}
1216
1217#[cfg(test)]
1218mod tests {
1219    use super::*;
1220    use crate::query::plan::{
1221        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1222        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1223    };
1224    use grafeo_common::types::Value;
1225
1226    #[test]
1227    fn test_node_scan_with_stats() {
1228        let mut estimator = CardinalityEstimator::new();
1229        estimator.add_table_stats("Person", TableStats::new(5000));
1230
1231        let scan = LogicalOperator::NodeScan(NodeScanOp {
1232            variable: "n".to_string(),
1233            label: Some("Person".to_string()),
1234            input: None,
1235        });
1236
1237        let cardinality = estimator.estimate(&scan);
1238        assert!((cardinality - 5000.0).abs() < 0.001);
1239    }
1240
1241    #[test]
1242    fn test_filter_reduces_cardinality() {
1243        let mut estimator = CardinalityEstimator::new();
1244        estimator.add_table_stats("Person", TableStats::new(1000));
1245
1246        let filter = LogicalOperator::Filter(FilterOp {
1247            predicate: LogicalExpression::Binary {
1248                left: Box::new(LogicalExpression::Property {
1249                    variable: "n".to_string(),
1250                    property: "age".to_string(),
1251                }),
1252                op: BinaryOp::Eq,
1253                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1254            },
1255            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1256                variable: "n".to_string(),
1257                label: Some("Person".to_string()),
1258                input: None,
1259            })),
1260            pushdown_hint: None,
1261        });
1262
1263        let cardinality = estimator.estimate(&filter);
1264        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1265        assert!(cardinality < 1000.0);
1266        assert!(cardinality >= 1.0);
1267    }
1268
1269    #[test]
1270    fn test_join_cardinality() {
1271        let mut estimator = CardinalityEstimator::new();
1272        estimator.add_table_stats("Person", TableStats::new(1000));
1273        estimator.add_table_stats("Company", TableStats::new(100));
1274
1275        let join = LogicalOperator::Join(JoinOp {
1276            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1277                variable: "p".to_string(),
1278                label: Some("Person".to_string()),
1279                input: None,
1280            })),
1281            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1282                variable: "c".to_string(),
1283                label: Some("Company".to_string()),
1284                input: None,
1285            })),
1286            join_type: JoinType::Inner,
1287            conditions: vec![JoinCondition {
1288                left: LogicalExpression::Property {
1289                    variable: "p".to_string(),
1290                    property: "company_id".to_string(),
1291                },
1292                right: LogicalExpression::Property {
1293                    variable: "c".to_string(),
1294                    property: "id".to_string(),
1295                },
1296            }],
1297        });
1298
1299        let cardinality = estimator.estimate(&join);
1300        // Should be less than cross product
1301        assert!(cardinality < 1000.0 * 100.0);
1302    }
1303
1304    #[test]
1305    fn test_limit_caps_cardinality() {
1306        let mut estimator = CardinalityEstimator::new();
1307        estimator.add_table_stats("Person", TableStats::new(1000));
1308
1309        let limit = LogicalOperator::Limit(LimitOp {
1310            count: 10.into(),
1311            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1312                variable: "n".to_string(),
1313                label: Some("Person".to_string()),
1314                input: None,
1315            })),
1316        });
1317
1318        let cardinality = estimator.estimate(&limit);
1319        assert!((cardinality - 10.0).abs() < 0.001);
1320    }
1321
1322    #[test]
1323    fn test_aggregate_reduces_cardinality() {
1324        let mut estimator = CardinalityEstimator::new();
1325        estimator.add_table_stats("Person", TableStats::new(1000));
1326
1327        // Global aggregation
1328        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1329            group_by: vec![],
1330            aggregates: vec![],
1331            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1332                variable: "n".to_string(),
1333                label: Some("Person".to_string()),
1334                input: None,
1335            })),
1336            having: None,
1337        });
1338
1339        let cardinality = estimator.estimate(&global_agg);
1340        assert!((cardinality - 1.0).abs() < 0.001);
1341
1342        // Group by aggregation
1343        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1344            group_by: vec![LogicalExpression::Property {
1345                variable: "n".to_string(),
1346                property: "city".to_string(),
1347            }],
1348            aggregates: vec![],
1349            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1350                variable: "n".to_string(),
1351                label: Some("Person".to_string()),
1352                input: None,
1353            })),
1354            having: None,
1355        });
1356
1357        let cardinality = estimator.estimate(&group_agg);
1358        // Should be less than input
1359        assert!(cardinality < 1000.0);
1360    }
1361
1362    #[test]
1363    fn test_node_scan_without_stats() {
1364        let estimator = CardinalityEstimator::new();
1365
1366        let scan = LogicalOperator::NodeScan(NodeScanOp {
1367            variable: "n".to_string(),
1368            label: Some("Unknown".to_string()),
1369            input: None,
1370        });
1371
1372        let cardinality = estimator.estimate(&scan);
1373        // Should return default (1000)
1374        assert!((cardinality - 1000.0).abs() < 0.001);
1375    }
1376
1377    #[test]
1378    fn test_node_scan_no_label() {
1379        let estimator = CardinalityEstimator::new();
1380
1381        let scan = LogicalOperator::NodeScan(NodeScanOp {
1382            variable: "n".to_string(),
1383            label: None,
1384            input: None,
1385        });
1386
1387        let cardinality = estimator.estimate(&scan);
1388        // Should scan all nodes (default)
1389        assert!((cardinality - 1000.0).abs() < 0.001);
1390    }
1391
1392    #[test]
1393    fn test_filter_inequality_selectivity() {
1394        let mut estimator = CardinalityEstimator::new();
1395        estimator.add_table_stats("Person", TableStats::new(1000));
1396
1397        let filter = LogicalOperator::Filter(FilterOp {
1398            predicate: LogicalExpression::Binary {
1399                left: Box::new(LogicalExpression::Property {
1400                    variable: "n".to_string(),
1401                    property: "age".to_string(),
1402                }),
1403                op: BinaryOp::Ne,
1404                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1405            },
1406            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1407                variable: "n".to_string(),
1408                label: Some("Person".to_string()),
1409                input: None,
1410            })),
1411            pushdown_hint: None,
1412        });
1413
1414        let cardinality = estimator.estimate(&filter);
1415        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1416        assert!(cardinality > 900.0);
1417    }
1418
1419    #[test]
1420    fn test_filter_range_selectivity() {
1421        let mut estimator = CardinalityEstimator::new();
1422        estimator.add_table_stats("Person", TableStats::new(1000));
1423
1424        let filter = LogicalOperator::Filter(FilterOp {
1425            predicate: LogicalExpression::Binary {
1426                left: Box::new(LogicalExpression::Property {
1427                    variable: "n".to_string(),
1428                    property: "age".to_string(),
1429                }),
1430                op: BinaryOp::Gt,
1431                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1432            },
1433            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1434                variable: "n".to_string(),
1435                label: Some("Person".to_string()),
1436                input: None,
1437            })),
1438            pushdown_hint: None,
1439        });
1440
1441        let cardinality = estimator.estimate(&filter);
1442        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1443        assert!(cardinality < 500.0);
1444        assert!(cardinality > 100.0);
1445    }
1446
1447    #[test]
1448    fn test_filter_and_selectivity() {
1449        let mut estimator = CardinalityEstimator::new();
1450        estimator.add_table_stats("Person", TableStats::new(1000));
1451
1452        // Test AND with two equality predicates
1453        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1454        let filter = LogicalOperator::Filter(FilterOp {
1455            predicate: LogicalExpression::Binary {
1456                left: Box::new(LogicalExpression::Binary {
1457                    left: Box::new(LogicalExpression::Property {
1458                        variable: "n".to_string(),
1459                        property: "city".to_string(),
1460                    }),
1461                    op: BinaryOp::Eq,
1462                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1463                }),
1464                op: BinaryOp::And,
1465                right: Box::new(LogicalExpression::Binary {
1466                    left: Box::new(LogicalExpression::Property {
1467                        variable: "n".to_string(),
1468                        property: "age".to_string(),
1469                    }),
1470                    op: BinaryOp::Eq,
1471                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1472                }),
1473            },
1474            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1475                variable: "n".to_string(),
1476                label: Some("Person".to_string()),
1477                input: None,
1478            })),
1479            pushdown_hint: None,
1480        });
1481
1482        let cardinality = estimator.estimate(&filter);
1483        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1484        // 1000 * 0.0001 = 0.1, min is 1.0
1485        assert!(cardinality < 100.0);
1486        assert!(cardinality >= 1.0);
1487    }
1488
1489    #[test]
1490    fn test_filter_or_selectivity() {
1491        let mut estimator = CardinalityEstimator::new();
1492        estimator.add_table_stats("Person", TableStats::new(1000));
1493
1494        // Test OR with two equality predicates
1495        // Each equality has selectivity 0.01
1496        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1497        let filter = LogicalOperator::Filter(FilterOp {
1498            predicate: LogicalExpression::Binary {
1499                left: Box::new(LogicalExpression::Binary {
1500                    left: Box::new(LogicalExpression::Property {
1501                        variable: "n".to_string(),
1502                        property: "city".to_string(),
1503                    }),
1504                    op: BinaryOp::Eq,
1505                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1506                }),
1507                op: BinaryOp::Or,
1508                right: Box::new(LogicalExpression::Binary {
1509                    left: Box::new(LogicalExpression::Property {
1510                        variable: "n".to_string(),
1511                        property: "city".to_string(),
1512                    }),
1513                    op: BinaryOp::Eq,
1514                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1515                }),
1516            },
1517            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1518                variable: "n".to_string(),
1519                label: Some("Person".to_string()),
1520                input: None,
1521            })),
1522            pushdown_hint: None,
1523        });
1524
1525        let cardinality = estimator.estimate(&filter);
1526        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1527        assert!(cardinality < 100.0);
1528        assert!(cardinality >= 1.0);
1529    }
1530
1531    #[test]
1532    fn test_filter_literal_true() {
1533        let mut estimator = CardinalityEstimator::new();
1534        estimator.add_table_stats("Person", TableStats::new(1000));
1535
1536        let filter = LogicalOperator::Filter(FilterOp {
1537            predicate: LogicalExpression::Literal(Value::Bool(true)),
1538            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1539                variable: "n".to_string(),
1540                label: Some("Person".to_string()),
1541                input: None,
1542            })),
1543            pushdown_hint: None,
1544        });
1545
1546        let cardinality = estimator.estimate(&filter);
1547        // Literal true has selectivity 1.0
1548        assert!((cardinality - 1000.0).abs() < 0.001);
1549    }
1550
1551    #[test]
1552    fn test_filter_literal_false() {
1553        let mut estimator = CardinalityEstimator::new();
1554        estimator.add_table_stats("Person", TableStats::new(1000));
1555
1556        let filter = LogicalOperator::Filter(FilterOp {
1557            predicate: LogicalExpression::Literal(Value::Bool(false)),
1558            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1559                variable: "n".to_string(),
1560                label: Some("Person".to_string()),
1561                input: None,
1562            })),
1563            pushdown_hint: None,
1564        });
1565
1566        let cardinality = estimator.estimate(&filter);
1567        // Literal false has selectivity 0.0, but min is 1.0
1568        assert!((cardinality - 1.0).abs() < 0.001);
1569    }
1570
1571    #[test]
1572    fn test_unary_not_selectivity() {
1573        let mut estimator = CardinalityEstimator::new();
1574        estimator.add_table_stats("Person", TableStats::new(1000));
1575
1576        let filter = LogicalOperator::Filter(FilterOp {
1577            predicate: LogicalExpression::Unary {
1578                op: UnaryOp::Not,
1579                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1580            },
1581            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1582                variable: "n".to_string(),
1583                label: Some("Person".to_string()),
1584                input: None,
1585            })),
1586            pushdown_hint: None,
1587        });
1588
1589        let cardinality = estimator.estimate(&filter);
1590        // NOT inverts selectivity
1591        assert!(cardinality < 1000.0);
1592    }
1593
1594    #[test]
1595    fn test_unary_is_null_selectivity() {
1596        let mut estimator = CardinalityEstimator::new();
1597        estimator.add_table_stats("Person", TableStats::new(1000));
1598
1599        let filter = LogicalOperator::Filter(FilterOp {
1600            predicate: LogicalExpression::Unary {
1601                op: UnaryOp::IsNull,
1602                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1603            },
1604            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1605                variable: "n".to_string(),
1606                label: Some("Person".to_string()),
1607                input: None,
1608            })),
1609            pushdown_hint: None,
1610        });
1611
1612        let cardinality = estimator.estimate(&filter);
1613        // IS NULL has selectivity 0.05
1614        assert!(cardinality < 100.0);
1615    }
1616
1617    #[test]
1618    fn test_expand_cardinality() {
1619        let mut estimator = CardinalityEstimator::new();
1620        estimator.add_table_stats("Person", TableStats::new(100));
1621
1622        let expand = LogicalOperator::Expand(ExpandOp {
1623            from_variable: "a".to_string(),
1624            to_variable: "b".to_string(),
1625            edge_variable: None,
1626            direction: ExpandDirection::Outgoing,
1627            edge_types: vec![],
1628            min_hops: 1,
1629            max_hops: Some(1),
1630            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1631                variable: "a".to_string(),
1632                label: Some("Person".to_string()),
1633                input: None,
1634            })),
1635            path_alias: None,
1636            path_mode: PathMode::Walk,
1637        });
1638
1639        let cardinality = estimator.estimate(&expand);
1640        // Expand multiplies by fanout (10)
1641        assert!(cardinality > 100.0);
1642    }
1643
1644    #[test]
1645    fn test_expand_with_edge_type_filter() {
1646        let mut estimator = CardinalityEstimator::new();
1647        estimator.add_table_stats("Person", TableStats::new(100));
1648
1649        let expand = LogicalOperator::Expand(ExpandOp {
1650            from_variable: "a".to_string(),
1651            to_variable: "b".to_string(),
1652            edge_variable: None,
1653            direction: ExpandDirection::Outgoing,
1654            edge_types: vec!["KNOWS".to_string()],
1655            min_hops: 1,
1656            max_hops: Some(1),
1657            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1658                variable: "a".to_string(),
1659                label: Some("Person".to_string()),
1660                input: None,
1661            })),
1662            path_alias: None,
1663            path_mode: PathMode::Walk,
1664        });
1665
1666        let cardinality = estimator.estimate(&expand);
1667        // With edge type, fanout is reduced by half
1668        assert!(cardinality > 100.0);
1669    }
1670
1671    #[test]
1672    fn test_expand_variable_length() {
1673        let mut estimator = CardinalityEstimator::new();
1674        estimator.add_table_stats("Person", TableStats::new(100));
1675
1676        let expand = LogicalOperator::Expand(ExpandOp {
1677            from_variable: "a".to_string(),
1678            to_variable: "b".to_string(),
1679            edge_variable: None,
1680            direction: ExpandDirection::Outgoing,
1681            edge_types: vec![],
1682            min_hops: 1,
1683            max_hops: Some(3),
1684            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1685                variable: "a".to_string(),
1686                label: Some("Person".to_string()),
1687                input: None,
1688            })),
1689            path_alias: None,
1690            path_mode: PathMode::Walk,
1691        });
1692
1693        let cardinality = estimator.estimate(&expand);
1694        // Variable length path has much higher cardinality
1695        assert!(cardinality > 500.0);
1696    }
1697
1698    #[test]
1699    fn test_join_cross_product() {
1700        let mut estimator = CardinalityEstimator::new();
1701        estimator.add_table_stats("Person", TableStats::new(100));
1702        estimator.add_table_stats("Company", TableStats::new(50));
1703
1704        let join = LogicalOperator::Join(JoinOp {
1705            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1706                variable: "p".to_string(),
1707                label: Some("Person".to_string()),
1708                input: None,
1709            })),
1710            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1711                variable: "c".to_string(),
1712                label: Some("Company".to_string()),
1713                input: None,
1714            })),
1715            join_type: JoinType::Cross,
1716            conditions: vec![],
1717        });
1718
1719        let cardinality = estimator.estimate(&join);
1720        // Cross join = 100 * 50 = 5000
1721        assert!((cardinality - 5000.0).abs() < 0.001);
1722    }
1723
1724    #[test]
1725    fn test_join_left_outer() {
1726        let mut estimator = CardinalityEstimator::new();
1727        estimator.add_table_stats("Person", TableStats::new(1000));
1728        estimator.add_table_stats("Company", TableStats::new(10));
1729
1730        let join = LogicalOperator::Join(JoinOp {
1731            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1732                variable: "p".to_string(),
1733                label: Some("Person".to_string()),
1734                input: None,
1735            })),
1736            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1737                variable: "c".to_string(),
1738                label: Some("Company".to_string()),
1739                input: None,
1740            })),
1741            join_type: JoinType::Left,
1742            conditions: vec![JoinCondition {
1743                left: LogicalExpression::Variable("p".to_string()),
1744                right: LogicalExpression::Variable("c".to_string()),
1745            }],
1746        });
1747
1748        let cardinality = estimator.estimate(&join);
1749        // Left join returns at least all left rows
1750        assert!(cardinality >= 1000.0);
1751    }
1752
1753    #[test]
1754    fn test_join_semi() {
1755        let mut estimator = CardinalityEstimator::new();
1756        estimator.add_table_stats("Person", TableStats::new(1000));
1757        estimator.add_table_stats("Company", TableStats::new(100));
1758
1759        let join = LogicalOperator::Join(JoinOp {
1760            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1761                variable: "p".to_string(),
1762                label: Some("Person".to_string()),
1763                input: None,
1764            })),
1765            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1766                variable: "c".to_string(),
1767                label: Some("Company".to_string()),
1768                input: None,
1769            })),
1770            join_type: JoinType::Semi,
1771            conditions: vec![],
1772        });
1773
1774        let cardinality = estimator.estimate(&join);
1775        // Semi join returns at most left cardinality
1776        assert!(cardinality <= 1000.0);
1777    }
1778
1779    #[test]
1780    fn test_join_anti() {
1781        let mut estimator = CardinalityEstimator::new();
1782        estimator.add_table_stats("Person", TableStats::new(1000));
1783        estimator.add_table_stats("Company", TableStats::new(100));
1784
1785        let join = LogicalOperator::Join(JoinOp {
1786            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1787                variable: "p".to_string(),
1788                label: Some("Person".to_string()),
1789                input: None,
1790            })),
1791            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1792                variable: "c".to_string(),
1793                label: Some("Company".to_string()),
1794                input: None,
1795            })),
1796            join_type: JoinType::Anti,
1797            conditions: vec![],
1798        });
1799
1800        let cardinality = estimator.estimate(&join);
1801        // Anti join returns at most left cardinality
1802        assert!(cardinality <= 1000.0);
1803        assert!(cardinality >= 1.0);
1804    }
1805
1806    #[test]
1807    fn test_project_preserves_cardinality() {
1808        let mut estimator = CardinalityEstimator::new();
1809        estimator.add_table_stats("Person", TableStats::new(1000));
1810
1811        let project = LogicalOperator::Project(ProjectOp {
1812            projections: vec![Projection {
1813                expression: LogicalExpression::Variable("n".to_string()),
1814                alias: None,
1815            }],
1816            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1817                variable: "n".to_string(),
1818                label: Some("Person".to_string()),
1819                input: None,
1820            })),
1821            pass_through_input: false,
1822        });
1823
1824        let cardinality = estimator.estimate(&project);
1825        assert!((cardinality - 1000.0).abs() < 0.001);
1826    }
1827
1828    #[test]
1829    fn test_sort_preserves_cardinality() {
1830        let mut estimator = CardinalityEstimator::new();
1831        estimator.add_table_stats("Person", TableStats::new(1000));
1832
1833        let sort = LogicalOperator::Sort(SortOp {
1834            keys: vec![SortKey {
1835                expression: LogicalExpression::Variable("n".to_string()),
1836                order: SortOrder::Ascending,
1837                nulls: None,
1838            }],
1839            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1840                variable: "n".to_string(),
1841                label: Some("Person".to_string()),
1842                input: None,
1843            })),
1844        });
1845
1846        let cardinality = estimator.estimate(&sort);
1847        assert!((cardinality - 1000.0).abs() < 0.001);
1848    }
1849
1850    #[test]
1851    fn test_distinct_reduces_cardinality() {
1852        let mut estimator = CardinalityEstimator::new();
1853        estimator.add_table_stats("Person", TableStats::new(1000));
1854
1855        let distinct = LogicalOperator::Distinct(DistinctOp {
1856            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1857                variable: "n".to_string(),
1858                label: Some("Person".to_string()),
1859                input: None,
1860            })),
1861            columns: None,
1862        });
1863
1864        let cardinality = estimator.estimate(&distinct);
1865        // Distinct assumes 50% unique
1866        assert!((cardinality - 500.0).abs() < 0.001);
1867    }
1868
1869    #[test]
1870    fn test_skip_reduces_cardinality() {
1871        let mut estimator = CardinalityEstimator::new();
1872        estimator.add_table_stats("Person", TableStats::new(1000));
1873
1874        let skip = LogicalOperator::Skip(SkipOp {
1875            count: 100.into(),
1876            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1877                variable: "n".to_string(),
1878                label: Some("Person".to_string()),
1879                input: None,
1880            })),
1881        });
1882
1883        let cardinality = estimator.estimate(&skip);
1884        assert!((cardinality - 900.0).abs() < 0.001);
1885    }
1886
1887    #[test]
1888    fn test_return_preserves_cardinality() {
1889        let mut estimator = CardinalityEstimator::new();
1890        estimator.add_table_stats("Person", TableStats::new(1000));
1891
1892        let ret = LogicalOperator::Return(ReturnOp {
1893            items: vec![ReturnItem {
1894                expression: LogicalExpression::Variable("n".to_string()),
1895                alias: None,
1896            }],
1897            distinct: false,
1898            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1899                variable: "n".to_string(),
1900                label: Some("Person".to_string()),
1901                input: None,
1902            })),
1903        });
1904
1905        let cardinality = estimator.estimate(&ret);
1906        assert!((cardinality - 1000.0).abs() < 0.001);
1907    }
1908
1909    #[test]
1910    fn test_empty_cardinality() {
1911        let estimator = CardinalityEstimator::new();
1912        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1913        assert!((cardinality).abs() < 0.001);
1914    }
1915
1916    #[test]
1917    fn test_table_stats_with_column() {
1918        let stats = TableStats::new(1000).with_column(
1919            "age",
1920            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1921        );
1922
1923        assert_eq!(stats.row_count, 1000);
1924        let col = stats.columns.get("age").unwrap();
1925        assert_eq!(col.distinct_count, 50);
1926        assert_eq!(col.null_count, 10);
1927        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1928        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1929    }
1930
1931    #[test]
1932    fn test_estimator_default() {
1933        let estimator = CardinalityEstimator::default();
1934        let scan = LogicalOperator::NodeScan(NodeScanOp {
1935            variable: "n".to_string(),
1936            label: None,
1937            input: None,
1938        });
1939        let cardinality = estimator.estimate(&scan);
1940        assert!((cardinality - 1000.0).abs() < 0.001);
1941    }
1942
1943    #[test]
1944    fn test_set_avg_fanout() {
1945        let mut estimator = CardinalityEstimator::new();
1946        estimator.add_table_stats("Person", TableStats::new(100));
1947        estimator.set_avg_fanout(5.0);
1948
1949        let expand = LogicalOperator::Expand(ExpandOp {
1950            from_variable: "a".to_string(),
1951            to_variable: "b".to_string(),
1952            edge_variable: None,
1953            direction: ExpandDirection::Outgoing,
1954            edge_types: vec![],
1955            min_hops: 1,
1956            max_hops: Some(1),
1957            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1958                variable: "a".to_string(),
1959                label: Some("Person".to_string()),
1960                input: None,
1961            })),
1962            path_alias: None,
1963            path_mode: PathMode::Walk,
1964        });
1965
1966        let cardinality = estimator.estimate(&expand);
1967        // With fanout 5: 100 * 5 = 500
1968        assert!((cardinality - 500.0).abs() < 0.001);
1969    }
1970
1971    #[test]
1972    fn test_multiple_group_by_keys_reduce_cardinality() {
1973        // The current implementation uses a simplified model where more group by keys
1974        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1975        // that works for most cases where group by keys are correlated.
1976        let mut estimator = CardinalityEstimator::new();
1977        estimator.add_table_stats("Person", TableStats::new(10000));
1978
1979        let single_group = LogicalOperator::Aggregate(AggregateOp {
1980            group_by: vec![LogicalExpression::Property {
1981                variable: "n".to_string(),
1982                property: "city".to_string(),
1983            }],
1984            aggregates: vec![],
1985            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1986                variable: "n".to_string(),
1987                label: Some("Person".to_string()),
1988                input: None,
1989            })),
1990            having: None,
1991        });
1992
1993        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1994            group_by: vec![
1995                LogicalExpression::Property {
1996                    variable: "n".to_string(),
1997                    property: "city".to_string(),
1998                },
1999                LogicalExpression::Property {
2000                    variable: "n".to_string(),
2001                    property: "country".to_string(),
2002                },
2003            ],
2004            aggregates: vec![],
2005            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2006                variable: "n".to_string(),
2007                label: Some("Person".to_string()),
2008                input: None,
2009            })),
2010            having: None,
2011        });
2012
2013        let single_card = estimator.estimate(&single_group);
2014        let multi_card = estimator.estimate(&multi_group);
2015
2016        // Both should reduce cardinality from input
2017        assert!(single_card < 10000.0);
2018        assert!(multi_card < 10000.0);
2019        // Both should be at least 1
2020        assert!(single_card >= 1.0);
2021        assert!(multi_card >= 1.0);
2022    }
2023
2024    // ============= Histogram Tests =============
2025
2026    #[test]
2027    fn test_histogram_build_uniform() {
2028        // Build histogram from uniformly distributed data
2029        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2030        let histogram = EquiDepthHistogram::build(&values, 10);
2031
2032        assert_eq!(histogram.num_buckets(), 10);
2033        assert_eq!(histogram.total_rows(), 100);
2034
2035        // Each bucket should have approximately 10 rows
2036        for bucket in histogram.buckets() {
2037            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
2038        }
2039    }
2040
2041    #[test]
2042    fn test_histogram_build_skewed() {
2043        // Build histogram from skewed data (many small values, few large)
2044        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
2045        values.extend((0..20).map(|i| 1000.0 + i as f64));
2046        let histogram = EquiDepthHistogram::build(&values, 5);
2047
2048        assert_eq!(histogram.num_buckets(), 5);
2049        assert_eq!(histogram.total_rows(), 100);
2050
2051        // Each bucket should have ~20 rows despite skewed data
2052        for bucket in histogram.buckets() {
2053            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
2054        }
2055    }
2056
2057    #[test]
2058    fn test_histogram_range_selectivity_full() {
2059        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2060        let histogram = EquiDepthHistogram::build(&values, 10);
2061
2062        // Full range should have selectivity ~1.0
2063        let selectivity = histogram.range_selectivity(None, None);
2064        assert!((selectivity - 1.0).abs() < 0.01);
2065    }
2066
2067    #[test]
2068    fn test_histogram_range_selectivity_half() {
2069        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2070        let histogram = EquiDepthHistogram::build(&values, 10);
2071
2072        // Values >= 50 should be ~50% (half the data)
2073        let selectivity = histogram.range_selectivity(Some(50.0), None);
2074        assert!(selectivity > 0.4 && selectivity < 0.6);
2075    }
2076
2077    #[test]
2078    fn test_histogram_range_selectivity_quarter() {
2079        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2080        let histogram = EquiDepthHistogram::build(&values, 10);
2081
2082        // Values < 25 should be ~25%
2083        let selectivity = histogram.range_selectivity(None, Some(25.0));
2084        assert!(selectivity > 0.2 && selectivity < 0.3);
2085    }
2086
2087    #[test]
2088    fn test_histogram_equality_selectivity() {
2089        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2090        let histogram = EquiDepthHistogram::build(&values, 10);
2091
2092        // Equality on 100 distinct values should be ~1%
2093        let selectivity = histogram.equality_selectivity(50.0);
2094        assert!(selectivity > 0.005 && selectivity < 0.02);
2095    }
2096
2097    #[test]
2098    fn test_histogram_empty() {
2099        let histogram = EquiDepthHistogram::build(&[], 10);
2100
2101        assert_eq!(histogram.num_buckets(), 0);
2102        assert_eq!(histogram.total_rows(), 0);
2103
2104        // Default selectivity for empty histogram
2105        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
2106        assert!((selectivity - 0.33).abs() < 0.01);
2107    }
2108
2109    #[test]
2110    fn test_histogram_bucket_overlap() {
2111        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
2112
2113        // Full overlap
2114        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
2115
2116        // Half overlap (lower half)
2117        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
2118
2119        // Half overlap (upper half)
2120        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
2121
2122        // No overlap (below)
2123        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
2124
2125        // No overlap (above)
2126        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
2127    }
2128
2129    #[test]
2130    fn test_column_stats_from_values() {
2131        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2132        let stats = ColumnStats::from_values(values, 4);
2133
2134        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
2135        assert!(stats.min_value.is_some());
2136        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2137        assert!(stats.max_value.is_some());
2138        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2139        assert!(stats.histogram.is_some());
2140    }
2141
2142    #[test]
2143    fn test_column_stats_with_histogram_builder() {
2144        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2145        let histogram = EquiDepthHistogram::build(&values, 10);
2146
2147        let stats = ColumnStats::new(100)
2148            .with_range(0.0, 99.0)
2149            .with_histogram(histogram);
2150
2151        assert!(stats.histogram.is_some());
2152        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2153    }
2154
2155    #[test]
2156    fn test_filter_with_histogram_stats() {
2157        let mut estimator = CardinalityEstimator::new();
2158
2159        // Create stats with histogram for age column
2160        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2161        let histogram = EquiDepthHistogram::build(&age_values, 10);
2162        let age_stats = ColumnStats::new(62)
2163            .with_range(18.0, 79.0)
2164            .with_histogram(histogram);
2165
2166        estimator.add_table_stats(
2167            "Person",
2168            TableStats::new(1000).with_column("age", age_stats),
2169        );
2170
2171        // Filter: age > 50
2172        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
2173        let filter = LogicalOperator::Filter(FilterOp {
2174            predicate: LogicalExpression::Binary {
2175                left: Box::new(LogicalExpression::Property {
2176                    variable: "n".to_string(),
2177                    property: "age".to_string(),
2178                }),
2179                op: BinaryOp::Gt,
2180                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2181            },
2182            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2183                variable: "n".to_string(),
2184                label: Some("Person".to_string()),
2185                input: None,
2186            })),
2187            pushdown_hint: None,
2188        });
2189
2190        let cardinality = estimator.estimate(&filter);
2191
2192        // With histogram, should get more accurate estimate than default 0.33
2193        // Expected: ~47.5% of 1000 = ~475
2194        assert!(cardinality > 300.0 && cardinality < 600.0);
2195    }
2196
2197    #[test]
2198    fn test_filter_equality_with_histogram() {
2199        let mut estimator = CardinalityEstimator::new();
2200
2201        // Create stats with histogram
2202        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2203        let histogram = EquiDepthHistogram::build(&values, 10);
2204        let stats = ColumnStats::new(100)
2205            .with_range(0.0, 99.0)
2206            .with_histogram(histogram);
2207
2208        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2209
2210        // Filter: value = 50
2211        let filter = LogicalOperator::Filter(FilterOp {
2212            predicate: LogicalExpression::Binary {
2213                left: Box::new(LogicalExpression::Property {
2214                    variable: "d".to_string(),
2215                    property: "value".to_string(),
2216                }),
2217                op: BinaryOp::Eq,
2218                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2219            },
2220            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2221                variable: "d".to_string(),
2222                label: Some("Data".to_string()),
2223                input: None,
2224            })),
2225            pushdown_hint: None,
2226        });
2227
2228        let cardinality = estimator.estimate(&filter);
2229
2230        // With 100 distinct values, selectivity should be ~1%
2231        // 1000 * 0.01 = 10
2232        assert!((1.0..50.0).contains(&cardinality));
2233    }
2234
2235    #[test]
2236    fn test_histogram_min_max() {
2237        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2238        let histogram = EquiDepthHistogram::build(&values, 2);
2239
2240        assert_eq!(histogram.min_value(), Some(5.0));
2241        // Max is the upper bound of the last bucket
2242        assert!(histogram.max_value().is_some());
2243    }
2244
2245    // ==================== SelectivityConfig Tests ====================
2246
2247    #[test]
2248    fn test_selectivity_config_defaults() {
2249        let config = SelectivityConfig::new();
2250        assert!((config.default - 0.1).abs() < f64::EPSILON);
2251        assert!((config.equality - 0.01).abs() < f64::EPSILON);
2252        assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2253        assert!((config.range - 0.33).abs() < f64::EPSILON);
2254        assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2255        assert!((config.membership - 0.1).abs() < f64::EPSILON);
2256        assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2257        assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2258        assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2259    }
2260
2261    #[test]
2262    fn test_custom_selectivity_config() {
2263        let config = SelectivityConfig {
2264            equality: 0.05,
2265            range: 0.25,
2266            ..SelectivityConfig::new()
2267        };
2268        let estimator = CardinalityEstimator::with_selectivity_config(config);
2269        assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2270        assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2271    }
2272
2273    #[test]
2274    fn test_custom_selectivity_affects_estimation() {
2275        // Default: equality = 0.01 → 1000 * 0.01 = 10
2276        let mut default_est = CardinalityEstimator::new();
2277        default_est.add_table_stats("Person", TableStats::new(1000));
2278
2279        let filter = LogicalOperator::Filter(FilterOp {
2280            predicate: LogicalExpression::Binary {
2281                left: Box::new(LogicalExpression::Property {
2282                    variable: "n".to_string(),
2283                    property: "name".to_string(),
2284                }),
2285                op: BinaryOp::Eq,
2286                right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2287            },
2288            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2289                variable: "n".to_string(),
2290                label: Some("Person".to_string()),
2291                input: None,
2292            })),
2293            pushdown_hint: None,
2294        });
2295
2296        let default_card = default_est.estimate(&filter);
2297
2298        // Custom: equality = 0.2 → 1000 * 0.2 = 200
2299        let config = SelectivityConfig {
2300            equality: 0.2,
2301            ..SelectivityConfig::new()
2302        };
2303        let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2304        custom_est.add_table_stats("Person", TableStats::new(1000));
2305
2306        let custom_card = custom_est.estimate(&filter);
2307
2308        assert!(custom_card > default_card);
2309        assert!((custom_card - 200.0).abs() < 1.0);
2310    }
2311
2312    #[test]
2313    fn test_custom_range_selectivity() {
2314        let config = SelectivityConfig {
2315            range: 0.5,
2316            ..SelectivityConfig::new()
2317        };
2318        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2319        estimator.add_table_stats("Person", TableStats::new(1000));
2320
2321        let filter = LogicalOperator::Filter(FilterOp {
2322            predicate: LogicalExpression::Binary {
2323                left: Box::new(LogicalExpression::Property {
2324                    variable: "n".to_string(),
2325                    property: "age".to_string(),
2326                }),
2327                op: BinaryOp::Gt,
2328                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2329            },
2330            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2331                variable: "n".to_string(),
2332                label: Some("Person".to_string()),
2333                input: None,
2334            })),
2335            pushdown_hint: None,
2336        });
2337
2338        let cardinality = estimator.estimate(&filter);
2339        // 1000 * 0.5 = 500
2340        assert!((cardinality - 500.0).abs() < 1.0);
2341    }
2342
2343    #[test]
2344    fn test_custom_distinct_fraction() {
2345        let config = SelectivityConfig {
2346            distinct_fraction: 0.8,
2347            ..SelectivityConfig::new()
2348        };
2349        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2350        estimator.add_table_stats("Person", TableStats::new(1000));
2351
2352        let distinct = LogicalOperator::Distinct(DistinctOp {
2353            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2354                variable: "n".to_string(),
2355                label: Some("Person".to_string()),
2356                input: None,
2357            })),
2358            columns: None,
2359        });
2360
2361        let cardinality = estimator.estimate(&distinct);
2362        // 1000 * 0.8 = 800
2363        assert!((cardinality - 800.0).abs() < 1.0);
2364    }
2365
2366    // ==================== EstimationLog Tests ====================
2367
2368    #[test]
2369    fn test_estimation_log_basic() {
2370        let mut log = EstimationLog::new(10.0);
2371        log.record("NodeScan(Person)", 1000.0, 1200.0);
2372        log.record("Filter(age > 30)", 100.0, 90.0);
2373
2374        assert_eq!(log.entries().len(), 2);
2375        assert!(!log.should_replan()); // 1.2x and 0.9x are within 10x threshold
2376    }
2377
2378    #[test]
2379    fn test_estimation_log_triggers_replan() {
2380        let mut log = EstimationLog::new(10.0);
2381        log.record("NodeScan(Person)", 100.0, 5000.0); // 50x underestimate
2382
2383        assert!(log.should_replan());
2384    }
2385
2386    #[test]
2387    fn test_estimation_log_overestimate_triggers_replan() {
2388        let mut log = EstimationLog::new(5.0);
2389        log.record("Filter", 1000.0, 100.0); // 10x overestimate → ratio = 0.1
2390
2391        assert!(log.should_replan()); // 0.1 < 1/5 = 0.2
2392    }
2393
2394    #[test]
2395    fn test_estimation_entry_error_ratio() {
2396        let entry = EstimationEntry {
2397            operator: "test".into(),
2398            estimated: 100.0,
2399            actual: 200.0,
2400        };
2401        assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2402
2403        let perfect = EstimationEntry {
2404            operator: "test".into(),
2405            estimated: 100.0,
2406            actual: 100.0,
2407        };
2408        assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2409
2410        let zero_est = EstimationEntry {
2411            operator: "test".into(),
2412            estimated: 0.0,
2413            actual: 0.0,
2414        };
2415        assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2416    }
2417
2418    #[test]
2419    fn test_estimation_log_max_error_ratio() {
2420        let mut log = EstimationLog::new(10.0);
2421        log.record("A", 100.0, 300.0); // 3x
2422        log.record("B", 100.0, 50.0); // 2x (normalized: 1/0.5 = 2)
2423        log.record("C", 100.0, 100.0); // 1x
2424
2425        assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2426    }
2427
2428    #[test]
2429    fn test_estimation_log_clear() {
2430        let mut log = EstimationLog::new(10.0);
2431        log.record("A", 100.0, 100.0);
2432        assert_eq!(log.entries().len(), 1);
2433
2434        log.clear();
2435        assert!(log.entries().is_empty());
2436        assert!(!log.should_replan());
2437    }
2438
2439    #[test]
2440    fn test_create_estimation_log() {
2441        let log = CardinalityEstimator::create_estimation_log();
2442        assert!(log.entries().is_empty());
2443        assert!(!log.should_replan());
2444    }
2445
2446    #[test]
2447    fn test_equality_selectivity_empty_histogram() {
2448        let hist = EquiDepthHistogram::new(vec![]);
2449        // Empty histogram returns fixed fallback
2450        assert_eq!(hist.equality_selectivity(5.0), 0.01);
2451    }
2452
2453    #[test]
2454    fn test_equality_selectivity_value_in_bucket() {
2455        let values: Vec<f64> = (1..=10).map(|i| i as f64).collect();
2456        let hist = EquiDepthHistogram::build(&values, 2);
2457        let sel = hist.equality_selectivity(3.0);
2458        assert!(sel > 0.0);
2459        assert!(sel <= 1.0);
2460    }
2461
2462    #[test]
2463    fn test_equality_selectivity_value_outside_all_buckets() {
2464        let values: Vec<f64> = (1..=10).map(|i| i as f64).collect();
2465        let hist = EquiDepthHistogram::build(&values, 2);
2466        // Value far outside range
2467        let sel = hist.equality_selectivity(9999.0);
2468        assert_eq!(sel, 0.001);
2469    }
2470
2471    #[test]
2472    fn test_histogram_min_max_empty() {
2473        let hist = EquiDepthHistogram::new(vec![]);
2474        assert_eq!(hist.min_value(), None);
2475        assert_eq!(hist.max_value(), None);
2476    }
2477
2478    #[test]
2479    fn test_histogram_min_max_single_bucket() {
2480        let hist = EquiDepthHistogram::new(vec![HistogramBucket::new(1.0, 10.0, 5, 5)]);
2481        assert_eq!(hist.min_value(), Some(1.0));
2482        assert_eq!(hist.max_value(), Some(10.0));
2483    }
2484
2485    #[test]
2486    fn test_histogram_min_max_multi_bucket() {
2487        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0];
2488        let hist = EquiDepthHistogram::build(&values, 3);
2489        let min = hist.min_value().unwrap();
2490        let max = hist.max_value().unwrap();
2491        assert!((min - 1.0).abs() < 1e-9, "min should be 1.0, got {min}");
2492        assert!(max >= 20.0, "max should be >= last value, got {max}");
2493    }
2494
2495    #[test]
2496    fn test_count_and_conjuncts_single_expression() {
2497        use crate::query::plan::LogicalExpression;
2498        let expr = LogicalExpression::Literal(Value::Bool(true));
2499        assert_eq!(count_and_conjuncts(&expr), 1);
2500    }
2501
2502    #[test]
2503    fn test_count_and_conjuncts_flat_and() {
2504        use crate::query::plan::{BinaryOp, LogicalExpression};
2505        let expr = LogicalExpression::Binary {
2506            left: Box::new(LogicalExpression::Literal(Value::Bool(true))),
2507            op: BinaryOp::And,
2508            right: Box::new(LogicalExpression::Literal(Value::Bool(false))),
2509        };
2510        assert_eq!(count_and_conjuncts(&expr), 2);
2511    }
2512
2513    #[test]
2514    fn test_count_and_conjuncts_nested_and() {
2515        use crate::query::plan::{BinaryOp, LogicalExpression};
2516        let ab = LogicalExpression::Binary {
2517            left: Box::new(LogicalExpression::Literal(Value::Bool(true))),
2518            op: BinaryOp::And,
2519            right: Box::new(LogicalExpression::Literal(Value::Bool(false))),
2520        };
2521        let cd = LogicalExpression::Binary {
2522            left: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2523            op: BinaryOp::And,
2524            right: Box::new(LogicalExpression::Literal(Value::Int64(2))),
2525        };
2526        let expr = LogicalExpression::Binary {
2527            left: Box::new(ab),
2528            op: BinaryOp::And,
2529            right: Box::new(cd),
2530        };
2531        assert_eq!(count_and_conjuncts(&expr), 4);
2532    }
2533
2534    #[test]
2535    fn test_count_distinct_empty() {
2536        assert_eq!(count_distinct(&[]), 0);
2537    }
2538
2539    #[test]
2540    fn test_count_distinct_all_unique() {
2541        assert_eq!(count_distinct(&[1.0, 2.0, 3.0, 4.0]), 4);
2542    }
2543
2544    #[test]
2545    fn test_count_distinct_with_duplicates() {
2546        assert_eq!(count_distinct(&[1.0, 1.0, 2.0, 2.0, 3.0]), 3);
2547    }
2548
2549    #[test]
2550    fn test_count_distinct_all_same() {
2551        assert_eq!(count_distinct(&[5.0, 5.0, 5.0]), 1);
2552    }
2553
2554    #[test]
2555    fn test_count_distinct_single_value() {
2556        assert_eq!(count_distinct(&[42.0]), 1);
2557    }
2558}