Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

//! Cardinality estimation for query optimization.
//!
//! Estimates the number of rows produced by each operator in a query plan.
//!
//! # Equi-Depth Histograms
//!
//! This module provides equi-depth histogram support for accurate selectivity
//! estimation of range predicates. Unlike equi-width histograms that divide
//! the value range into equal-sized buckets, equi-depth histograms divide
//! the data into buckets with approximately equal numbers of rows.
//!
//! Benefits:
//! - Better estimates for skewed data distributions
//! - More accurate range selectivity than assuming uniform distribution
//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19    LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20    TextScanOp, TripleComponent, TripleScanOp, UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram.
///
/// A bucket describes the value range `[lower_bound, upper_bound)` together
/// with the number of rows and the number of distinct values that fall into
/// that range. In an equi-depth histogram every bucket holds roughly the same
/// number of rows.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Upper edge of the bucket; [`contains`](Self::contains) treats it as exclusive.
    pub upper_bound: f64,
    /// Number of rows whose value lies in this bucket.
    pub frequency: u64,
    /// Number of distinct values observed in this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Builds a bucket from its bounds and counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` when `lower_bound <= value < upper_bound`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Returns the fraction (0.0 to 1.0) of this bucket covered by a query range.
    ///
    /// `None` for `lower` or `upper` means the range is unbounded on that side.
    /// The query range is first clipped to the bucket; the covered share is
    /// then the clipped length divided by the bucket width.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let clipped_lower = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let clipped_upper = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Zero-width (or inverted) bucket: there is no length to
            // interpolate over, so it is either fully covered or not at all.
            let fully_covered =
                clipped_lower <= self.lower_bound && clipped_upper >= self.upper_bound;
            return if fully_covered { 1.0 } else { 0.0 };
        }

        // A negative clipped length means the range misses the bucket entirely.
        let covered = (clipped_upper - clipped_lower).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```no_run
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower
213                && bucket_upper <= l
214            {
215                continue;
216            }
217            if let Some(u) = upper
218                && bucket_lower >= u
219            {
220                continue;
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
270/// Counts distinct values in a sorted slice.
271/// Counts the number of top-level AND conjuncts in an expression.
272///
273/// For `(A AND B) AND (C AND D)` returns 4; for a single non-AND expression
274/// returns 1. Used to estimate selectivity of multi-predicate join conditions.
275fn count_and_conjuncts(expr: &LogicalExpression) -> usize {
276    match expr {
277        LogicalExpression::Binary {
278            op: BinaryOp::And,
279            left,
280            right,
281        } => count_and_conjuncts(left) + count_and_conjuncts(right),
282        _ => 1,
283    }
284}
285
/// Counts distinct values in a sorted slice.
///
/// Walks the slice once and starts a new distinct value whenever an element
/// differs from the most recently counted value by more than `f64::EPSILON`.
/// Returns 0 for an empty slice. The slice is expected to be sorted so that
/// equal values are adjacent.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    let Some((&first, rest)) = sorted_values.split_first() else {
        return 0;
    };

    // State: (distinct values seen so far, last value that was counted).
    let (distinct, _last_counted) = rest.iter().fold((1u64, first), |(seen, anchor), &value| {
        if (value - anchor).abs() > f64::EPSILON {
            (seen + 1, value)
        } else {
            (seen, anchor)
        }
    });

    distinct
}
303
304/// Statistics for a table/label.
305#[derive(Debug, Clone)]
306pub struct TableStats {
307    /// Total number of rows.
308    pub row_count: u64,
309    /// Column statistics.
310    pub columns: HashMap<String, ColumnStats>,
311}
312
313impl TableStats {
314    /// Creates new table statistics.
315    #[must_use]
316    pub fn new(row_count: u64) -> Self {
317        Self {
318            row_count,
319            columns: HashMap::new(),
320        }
321    }
322
323    /// Adds column statistics.
324    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
325        self.columns.insert(name.to_string(), stats);
326        self
327    }
328}
329
330/// Statistics for a column.
331#[derive(Debug, Clone)]
332pub struct ColumnStats {
333    /// Number of distinct values.
334    pub distinct_count: u64,
335    /// Number of null values.
336    pub null_count: u64,
337    /// Minimum value (if orderable).
338    pub min_value: Option<f64>,
339    /// Maximum value (if orderable).
340    pub max_value: Option<f64>,
341    /// Equi-depth histogram for accurate selectivity estimation.
342    pub histogram: Option<EquiDepthHistogram>,
343}
344
345impl ColumnStats {
346    /// Creates new column statistics.
347    #[must_use]
348    pub fn new(distinct_count: u64) -> Self {
349        Self {
350            distinct_count,
351            null_count: 0,
352            min_value: None,
353            max_value: None,
354            histogram: None,
355        }
356    }
357
358    /// Sets the null count.
359    #[must_use]
360    pub fn with_nulls(mut self, null_count: u64) -> Self {
361        self.null_count = null_count;
362        self
363    }
364
365    /// Sets the min/max range.
366    #[must_use]
367    pub fn with_range(mut self, min: f64, max: f64) -> Self {
368        self.min_value = Some(min);
369        self.max_value = Some(max);
370        self
371    }
372
373    /// Sets the equi-depth histogram for this column.
374    #[must_use]
375    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
376        self.histogram = Some(histogram);
377        self
378    }
379
380    /// Builds column statistics with histogram from raw values.
381    ///
382    /// This is a convenience method that computes all statistics from the data.
383    ///
384    /// # Arguments
385    /// * `values` - The column values (will be sorted internally)
386    /// * `num_buckets` - Number of histogram buckets to create
387    #[must_use]
388    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
389        if values.is_empty() {
390            return Self::new(0);
391        }
392
393        // Sort values for histogram building
394        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
395
396        let min = values.first().copied();
397        let max = values.last().copied();
398        let distinct_count = count_distinct(&values);
399        let histogram = EquiDepthHistogram::build(&values, num_buckets);
400
401        Self {
402            distinct_count,
403            null_count: 0,
404            min_value: min,
405            max_value: max,
406            histogram: Some(histogram),
407        }
408    }
409}
410
/// Tunable fallback selectivities for cardinality estimation.
///
/// These values are the assumed selectivities for the various predicate
/// kinds when no histogram or column statistics are available. Adjusting
/// them can improve plan quality for workloads with known skew.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Selectivity for unknown predicates (default: 0.1).
    pub default: f64,
    /// Selectivity for equality predicates without stats (default: 0.01).
    pub equality: f64,
    /// Selectivity for inequality predicates (default: 0.99).
    pub inequality: f64,
    /// Selectivity for range predicates without stats (default: 0.33).
    pub range: f64,
    /// Selectivity for string operations: STARTS WITH, ENDS WITH, CONTAINS, LIKE (default: 0.1).
    pub string_ops: f64,
    /// Selectivity for IN membership (default: 0.1).
    pub membership: f64,
    /// Selectivity for IS NULL (default: 0.05).
    pub is_null: f64,
    /// Selectivity for IS NOT NULL (default: 0.95).
    pub is_not_null: f64,
    /// Fraction assumed distinct for DISTINCT operations (default: 0.5).
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Creates a config holding the default value documented on each field.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SelectivityConfig {
    /// Returns the standard defaults; [`SelectivityConfig::new`] is equivalent.
    fn default() -> Self {
        Self {
            default: 0.1,
            equality: 0.01,
            inequality: 0.99,
            range: 0.33,
            string_ops: 0.1,
            membership: 0.1,
            is_null: 0.05,
            is_not_null: 0.95,
            distinct_fraction: 0.5,
        }
    }
}
461
/// One estimate-vs-actual observation for a single operator.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Human-readable label for the operator (e.g., "NodeScan(Person)").
    pub operator: String,
    /// The cardinality estimate produced by the optimizer.
    pub estimated: f64,
    /// The actual row count observed at execution time.
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns the estimation error ratio, `actual / estimated`.
    ///
    /// * near 1.0: the estimate was accurate
    /// * above 1.0: the optimizer underestimated
    /// * below 1.0: the optimizer overestimated
    ///
    /// When the estimate is (effectively) zero the division is avoided: the
    /// ratio is 1.0 if the actual count is also zero and infinity otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        let negligible = |x: f64| x.abs() < f64::EPSILON;

        match (negligible(self.estimated), negligible(self.actual)) {
            (false, _) => self.actual / self.estimated,
            (true, true) => 1.0,
            (true, false) => f64::INFINITY,
        }
    }
}
492
493/// Collects estimate vs actual cardinality data for query plan analysis.
494///
495/// After executing a query, call [`record()`](Self::record) for each
496/// operator with its estimated and actual cardinalities. Then use
497/// [`should_replan()`](Self::should_replan) to decide whether the plan
498/// should be re-optimized.
499#[derive(Debug, Clone, Default)]
500pub struct EstimationLog {
501    /// Recorded entries.
502    entries: Vec<EstimationEntry>,
503    /// Error ratio threshold that triggers re-planning (default: 10.0).
504    ///
505    /// If any operator's error ratio exceeds this, `should_replan()` returns true.
506    replan_threshold: f64,
507}
508
509impl EstimationLog {
510    /// Creates a new estimation log with the given re-planning threshold.
511    #[must_use]
512    pub fn new(replan_threshold: f64) -> Self {
513        Self {
514            entries: Vec::new(),
515            replan_threshold,
516        }
517    }
518
519    /// Records an estimate-vs-actual observation.
520    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
521        self.entries.push(EstimationEntry {
522            operator: operator.into(),
523            estimated,
524            actual,
525        });
526    }
527
528    /// Returns all recorded entries.
529    #[must_use]
530    pub fn entries(&self) -> &[EstimationEntry] {
531        &self.entries
532    }
533
534    /// Returns whether any operator's estimation error exceeds the threshold,
535    /// indicating the plan should be re-optimized.
536    #[must_use]
537    pub fn should_replan(&self) -> bool {
538        self.entries.iter().any(|e| {
539            let ratio = e.error_ratio();
540            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
541        })
542    }
543
544    /// Returns the maximum error ratio across all entries.
545    #[must_use]
546    pub fn max_error_ratio(&self) -> f64 {
547        self.entries
548            .iter()
549            .map(|e| {
550                let r = e.error_ratio();
551                // Normalize so both over- and under-estimation are > 1.0
552                if r < 1.0 { 1.0 / r } else { r }
553            })
554            .fold(1.0_f64, f64::max)
555    }
556
557    /// Clears all entries.
558    pub fn clear(&mut self) {
559        self.entries.clear();
560    }
561}
562
563/// Cardinality estimator.
564pub struct CardinalityEstimator {
565    /// Statistics for each label/table.
566    table_stats: HashMap<String, TableStats>,
567    /// Default row count for unknown tables.
568    default_row_count: u64,
569    /// Default selectivity for unknown predicates.
570    default_selectivity: f64,
571    /// Average edge fanout (outgoing edges per node).
572    avg_fanout: f64,
573    /// Configurable selectivity defaults.
574    selectivity_config: SelectivityConfig,
575    /// RDF statistics for triple pattern cardinality estimation.
576    rdf_statistics: Option<grafeo_core::statistics::RdfStatistics>,
577}
578
579impl CardinalityEstimator {
580    /// Creates a new cardinality estimator with default settings.
581    #[must_use]
582    pub fn new() -> Self {
583        let config = SelectivityConfig::new();
584        Self {
585            table_stats: HashMap::new(),
586            default_row_count: 1000,
587            default_selectivity: config.default,
588            avg_fanout: 10.0,
589            selectivity_config: config,
590            rdf_statistics: None,
591        }
592    }
593
594    /// Creates a new cardinality estimator with custom selectivity configuration.
595    #[must_use]
596    pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
597        Self {
598            table_stats: HashMap::new(),
599            default_row_count: 1000,
600            default_selectivity: config.default,
601            avg_fanout: 10.0,
602            selectivity_config: config,
603            rdf_statistics: None,
604        }
605    }
606
607    /// Returns the current selectivity configuration.
608    #[must_use]
609    pub fn selectivity_config(&self) -> &SelectivityConfig {
610        &self.selectivity_config
611    }
612
613    /// Creates an estimation log with the default re-planning threshold (10x).
614    #[must_use]
615    pub fn create_estimation_log() -> EstimationLog {
616        EstimationLog::new(10.0)
617    }
618
619    /// Creates a cardinality estimator pre-populated from store statistics.
620    ///
621    /// Maps `LabelStatistics` to `TableStats` and computes the average edge
622    /// fanout from `EdgeTypeStatistics`. Falls back to defaults for any
623    /// missing statistics.
624    #[must_use]
625    pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
626        let mut estimator = Self::new();
627
628        // Use total node count as default for unlabeled scans
629        if stats.total_nodes > 0 {
630            estimator.default_row_count = stats.total_nodes;
631        }
632
633        // Convert label statistics to optimizer table stats
634        for (label, label_stats) in &stats.labels {
635            let mut table_stats = TableStats::new(label_stats.node_count);
636
637            // Map property statistics (distinct count for selectivity estimation)
638            for (prop, col_stats) in &label_stats.properties {
639                let optimizer_col =
640                    ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
641                table_stats = table_stats.with_column(prop, optimizer_col);
642            }
643
644            estimator.add_table_stats(label, table_stats);
645        }
646
647        // Compute average fanout from edge type statistics
648        if !stats.edge_types.is_empty() {
649            let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
650            estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
651        } else if stats.total_nodes > 0 {
652            estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
653        }
654
655        // Clamp fanout to a reasonable minimum
656        if estimator.avg_fanout < 1.0 {
657            estimator.avg_fanout = 1.0;
658        }
659
660        estimator
661    }
662
663    /// Creates a cardinality estimator from RDF statistics.
664    ///
665    /// Uses triple pattern cardinality estimates for `TripleScan` operators
666    /// and join selectivity from per-predicate statistics.
667    #[must_use]
668    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
669        let mut estimator = Self::new();
670        if rdf_stats.total_triples > 0 {
671            estimator.default_row_count = rdf_stats.total_triples;
672        }
673        estimator.rdf_statistics = Some(rdf_stats);
674        estimator
675    }
676
677    /// Adds statistics for a table/label.
678    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
679        self.table_stats.insert(name.to_string(), stats);
680    }
681
682    /// Sets the average edge fanout.
683    pub fn set_avg_fanout(&mut self, fanout: f64) {
684        self.avg_fanout = fanout;
685    }
686
687    /// Estimates the cardinality of a logical operator.
688    #[must_use]
689    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
690        match op {
691            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
692            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
693            LogicalOperator::Project(project) => self.estimate_project(project),
694            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
695            LogicalOperator::Join(join) => self.estimate_join(join),
696            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
697            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
698            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
699            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
700            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
701            LogicalOperator::Return(ret) => self.estimate(&ret.input),
702            LogicalOperator::Empty => 0.0,
703            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
704            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
705            LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
706            LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
707            LogicalOperator::TripleScan(scan) => self.estimate_triple_scan(scan),
708            LogicalOperator::TextScan(scan) => self.estimate_text_scan(scan),
709            _ => self.default_row_count as f64,
710        }
711    }
712
713    /// Estimates node scan cardinality.
714    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
715        if let Some(label) = &scan.label
716            && let Some(stats) = self.table_stats.get(label)
717        {
718            return stats.row_count as f64;
719        }
720        // No label filter - scan all nodes
721        self.default_row_count as f64
722    }
723
724    /// Estimates triple scan cardinality using RDF statistics.
725    ///
726    /// If RDF statistics are available, uses the pattern binding (which positions
727    /// are bound vs variable) to produce accurate estimates. Otherwise falls back
728    /// to the default row count.
729    fn estimate_triple_scan(&self, scan: &TripleScanOp) -> f64 {
730        // If there's an input, the triple scan is chained: multiply input cardinality
731        // by the per-row expansion factor.
732        let base = if let Some(ref input) = scan.input {
733            self.estimate(input)
734        } else {
735            1.0
736        };
737
738        let Some(rdf_stats) = &self.rdf_statistics else {
739            return if scan.input.is_some() {
740                base * self.default_row_count as f64
741            } else {
742                self.default_row_count as f64
743            };
744        };
745
746        let subject_bound = matches!(
747            scan.subject,
748            TripleComponent::Iri(_)
749                | TripleComponent::Literal(_)
750                | TripleComponent::LangLiteral { .. }
751        );
752        let object_bound = matches!(
753            scan.object,
754            TripleComponent::Iri(_)
755                | TripleComponent::Literal(_)
756                | TripleComponent::LangLiteral { .. }
757        );
758        let predicate_iri = match &scan.predicate {
759            TripleComponent::Iri(iri) => Some(iri.as_str()),
760            _ => None,
761        };
762
763        let pattern_card = rdf_stats.estimate_triple_pattern_cardinality(
764            subject_bound,
765            predicate_iri,
766            object_bound,
767        );
768
769        if scan.input.is_some() {
770            // Chained scan: each input row expands by the pattern's selectivity
771            let selectivity = if rdf_stats.total_triples > 0 {
772                pattern_card / rdf_stats.total_triples as f64
773            } else {
774                1.0
775            };
776            (base * pattern_card * selectivity).max(1.0)
777        } else {
778            pattern_card.max(1.0)
779        }
780    }
781
782    /// Estimates filter cardinality.
783    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
784        let input_cardinality = self.estimate(&filter.input);
785        let selectivity = self.estimate_selectivity(&filter.predicate);
786        (input_cardinality * selectivity).max(1.0)
787    }
788
789    /// Estimates projection cardinality (same as input).
790    fn estimate_project(&self, project: &ProjectOp) -> f64 {
791        self.estimate(&project.input)
792    }
793
794    /// Estimates expand cardinality.
795    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
796        let input_cardinality = self.estimate(&expand.input);
797
798        // Apply fanout based on edge type
799        let fanout = if !expand.edge_types.is_empty() {
800            // Specific edge type(s) typically have lower fanout
801            self.avg_fanout * 0.5
802        } else {
803            self.avg_fanout
804        };
805
806        // Handle variable-length paths
807        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
808            let min = expand.min_hops as f64;
809            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
810            // Geometric series approximation
811            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
812        } else {
813            fanout
814        };
815
816        (input_cardinality * path_multiplier).max(1.0)
817    }
818
819    /// Estimates join cardinality.
820    fn estimate_join(&self, join: &JoinOp) -> f64 {
821        let left_card = self.estimate(&join.left);
822        let right_card = self.estimate(&join.right);
823
824        match join.join_type {
825            JoinType::Cross => left_card * right_card,
826            JoinType::Inner => {
827                // Assume join selectivity based on conditions
828                let selectivity = if join.conditions.is_empty() {
829                    1.0 // Cross join
830                } else {
831                    // Estimate based on number of conditions
832                    // reason: join condition count is always small (< 100)
833                    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
834                    let exp = join.conditions.len() as i32;
835                    0.1_f64.powi(exp)
836                };
837                (left_card * right_card * selectivity).max(1.0)
838            }
839            JoinType::Left => {
840                // Left join returns at least all left rows
841                let inner_card = self.estimate_join(&JoinOp {
842                    left: join.left.clone(),
843                    right: join.right.clone(),
844                    join_type: JoinType::Inner,
845                    conditions: join.conditions.clone(),
846                });
847                inner_card.max(left_card)
848            }
849            JoinType::Right => {
850                // Right join returns at least all right rows
851                let inner_card = self.estimate_join(&JoinOp {
852                    left: join.left.clone(),
853                    right: join.right.clone(),
854                    join_type: JoinType::Inner,
855                    conditions: join.conditions.clone(),
856                });
857                inner_card.max(right_card)
858            }
859            JoinType::Full => {
860                // Full join returns at least max(left, right)
861                let inner_card = self.estimate_join(&JoinOp {
862                    left: join.left.clone(),
863                    right: join.right.clone(),
864                    join_type: JoinType::Inner,
865                    conditions: join.conditions.clone(),
866                });
867                inner_card.max(left_card.max(right_card))
868            }
869            JoinType::Semi => {
870                // Semi join returns at most left cardinality
871                (left_card * self.default_selectivity).max(1.0)
872            }
873            JoinType::Anti => {
874                // Anti join returns at most left cardinality
875                (left_card * (1.0 - self.default_selectivity)).max(1.0)
876            }
877        }
878    }
879
880    /// Estimates left join cardinality (OPTIONAL MATCH).
881    ///
882    /// A left outer join preserves all left rows, so the output is at least
883    /// `left_cardinality`. When the right side matches, the output may be
884    /// larger (one left row can match multiple right rows).
885    ///
886    /// When the join carries a cross-side condition (null-safe predicates),
887    /// each AND-conjunct reduces the selectivity estimate.
888    fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
889        let left_card = self.estimate(&lj.left);
890        let right_card = self.estimate(&lj.right);
891
892        // Adjust selectivity based on the number of AND conjuncts in the
893        // cross-side condition: each equality reduces match probability.
894        let condition_selectivity = if let Some(cond) = &lj.condition {
895            let n = count_and_conjuncts(cond);
896            // reason: conjunct count is always small (< 100)
897            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
898            let exp = n as i32;
899            self.default_selectivity.powi(exp)
900        } else {
901            self.default_selectivity
902        };
903
904        // Estimate as inner join cardinality, but guaranteed at least left_card
905        let inner_estimate = left_card * right_card * condition_selectivity;
906        inner_estimate.max(left_card).max(1.0)
907    }
908
909    /// Estimates aggregation cardinality.
910    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
911        let input_cardinality = self.estimate(&agg.input);
912
913        if agg.group_by.is_empty() {
914            // Global aggregation - single row
915            1.0
916        } else {
917            // Group by - estimate distinct groups
918            // Assume each group key reduces cardinality by 10
919            // reason: group-by key count is always small (< 100)
920            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
921            let exp = agg.group_by.len() as i32;
922            let group_reduction = 10.0_f64.powi(exp);
923            (input_cardinality / group_reduction).max(1.0)
924        }
925    }
926
927    /// Estimates sort cardinality (same as input).
928    fn estimate_sort(&self, sort: &SortOp) -> f64 {
929        self.estimate(&sort.input)
930    }
931
932    /// Estimates distinct cardinality.
933    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
934        let input_cardinality = self.estimate(&distinct.input);
935        (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
936    }
937
938    /// Estimates limit cardinality.
939    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
940        let input_cardinality = self.estimate(&limit.input);
941        limit.count.estimate().min(input_cardinality)
942    }
943
944    /// Estimates skip cardinality.
945    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
946        let input_cardinality = self.estimate(&skip.input);
947        (input_cardinality - skip.count.estimate()).max(0.0)
948    }
949
950    /// Estimates vector scan cardinality.
951    ///
952    /// Vector scan returns at most k results (the k nearest neighbors).
953    /// With similarity/distance filters, it may return fewer.
954    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
955        if let Some(k) = scan.k {
956            // Top-k mode: at most k results, reduced by filter selectivity
957            let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
958                0.7 // Assume 70% of results pass threshold filters
959            } else {
960                1.0
961            };
962            (k as f64 * selectivity).max(1.0)
963        } else {
964            // Threshold-only mode: estimate 20% of indexed nodes match
965            let base = scan
966                .label
967                .as_deref()
968                .and_then(|l| self.table_stats.get(l))
969                .map_or(self.default_row_count as f64, |s| s.row_count as f64);
970            (base * 0.2).max(1.0)
971        }
972    }
973
974    /// Estimates text scan cardinality using BM25 index.
975    ///
976    /// In top-k mode, returns at most k results. In threshold mode, assumes
977    /// 10% of the label's documents match the query.
978    fn estimate_text_scan(&self, scan: &TextScanOp) -> f64 {
979        if let Some(k) = scan.k {
980            // Top-k mode: at most k results
981            return k as f64;
982        }
983        if scan.threshold.is_some() {
984            // Threshold mode: estimate 10% of indexed documents match
985            let default_selectivity = 0.1;
986            let base = if let Some(stats) = self.table_stats.get(&scan.label) {
987                stats.row_count as f64
988            } else {
989                self.default_row_count as f64
990            };
991            return (base * default_selectivity).max(1.0);
992        }
993        // Default top-k(100) — matches executor fallback
994        100.0
995    }
996
997    /// Estimates vector join cardinality.
998    ///
999    /// Vector join produces up to k results per input row.
1000    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
1001        let input_cardinality = self.estimate(&join.input);
1002        let k = join.k as f64;
1003
1004        // Apply filter selectivity if thresholds are specified
1005        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
1006            0.7
1007        } else {
1008            1.0
1009        };
1010
1011        (input_cardinality * k * selectivity).max(1.0)
1012    }
1013
1014    /// Estimates multi-way join cardinality using the AGM bound heuristic.
1015    ///
1016    /// For a cyclic join of N relations, the AGM (Atserias-Grohe-Marx) bound
1017    /// gives min(cardinality)^(N/2) as a worst-case output size estimate.
1018    fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
1019        if mwj.inputs.is_empty() {
1020            return 0.0;
1021        }
1022        let cardinalities: Vec<f64> = mwj
1023            .inputs
1024            .iter()
1025            .map(|input| self.estimate(input))
1026            .collect();
1027        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
1028        let n = cardinalities.len() as f64;
1029        // AGM bound: min(cardinality)^(n/2)
1030        (min_card.powf(n / 2.0)).max(1.0)
1031    }
1032
1033    /// Estimates the selectivity of a predicate (0.0 to 1.0).
1034    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
1035        match expr {
1036            LogicalExpression::Binary { left, op, right } => {
1037                self.estimate_binary_selectivity(left, *op, right)
1038            }
1039            LogicalExpression::Unary { op, operand } => {
1040                self.estimate_unary_selectivity(*op, operand)
1041            }
1042            LogicalExpression::Literal(value) => {
1043                // Boolean literal
1044                if let grafeo_common::types::Value::Bool(b) = value {
1045                    if *b { 1.0 } else { 0.0 }
1046                } else {
1047                    self.default_selectivity
1048                }
1049            }
1050            _ => self.default_selectivity,
1051        }
1052    }
1053
1054    /// Estimates binary expression selectivity.
1055    fn estimate_binary_selectivity(
1056        &self,
1057        left: &LogicalExpression,
1058        op: BinaryOp,
1059        right: &LogicalExpression,
1060    ) -> f64 {
1061        match op {
1062            // Equality - try histogram-based estimation
1063            BinaryOp::Eq => {
1064                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
1065                    return selectivity;
1066                }
1067                self.selectivity_config.equality
1068            }
1069            // Inequality is very unselective
1070            BinaryOp::Ne => self.selectivity_config.inequality,
1071            // Range predicates - use histogram if available
1072            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
1073                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
1074                    return selectivity;
1075                }
1076                self.selectivity_config.range
1077            }
1078            // Logical operators - recursively estimate sub-expressions
1079            BinaryOp::And => {
1080                let left_sel = self.estimate_selectivity(left);
1081                let right_sel = self.estimate_selectivity(right);
1082                // AND reduces selectivity (multiply assuming independence)
1083                left_sel * right_sel
1084            }
1085            BinaryOp::Or => {
1086                let left_sel = self.estimate_selectivity(left);
1087                let right_sel = self.estimate_selectivity(right);
1088                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
1089                // Assuming independence: P(A ∩ B) = P(A) * P(B)
1090                (left_sel + right_sel - left_sel * right_sel).min(1.0)
1091            }
1092            // String operations
1093            BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
1094                self.selectivity_config.string_ops
1095            }
1096            // Collection membership
1097            BinaryOp::In => self.selectivity_config.membership,
1098            // Other operations
1099            _ => self.default_selectivity,
1100        }
1101    }
1102
1103    /// Tries to estimate equality selectivity using histograms.
1104    fn try_equality_selectivity(
1105        &self,
1106        left: &LogicalExpression,
1107        right: &LogicalExpression,
1108    ) -> Option<f64> {
1109        // Extract property access and literal value
1110        let (label, column, value) = self.extract_column_and_value(left, right)?;
1111
1112        // Get column stats with histogram
1113        let stats = self.get_column_stats(&label, &column)?;
1114
1115        // Try histogram-based estimation
1116        if let Some(ref histogram) = stats.histogram {
1117            return Some(histogram.equality_selectivity(value));
1118        }
1119
1120        // Fall back to distinct count estimation
1121        if stats.distinct_count > 0 {
1122            return Some(1.0 / stats.distinct_count as f64);
1123        }
1124
1125        None
1126    }
1127
1128    /// Tries to estimate range selectivity using histograms.
1129    fn try_range_selectivity(
1130        &self,
1131        left: &LogicalExpression,
1132        op: BinaryOp,
1133        right: &LogicalExpression,
1134    ) -> Option<f64> {
1135        // Extract property access and literal value
1136        let (label, column, value) = self.extract_column_and_value(left, right)?;
1137
1138        // Get column stats
1139        let stats = self.get_column_stats(&label, &column)?;
1140
1141        // Determine the range based on operator
1142        let (lower, upper) = match op {
1143            BinaryOp::Lt => (None, Some(value)),
1144            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1145            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1146            BinaryOp::Ge => (Some(value), None),
1147            _ => return None,
1148        };
1149
1150        // Try histogram-based estimation first
1151        if let Some(ref histogram) = stats.histogram {
1152            return Some(histogram.range_selectivity(lower, upper));
1153        }
1154
1155        // Fall back to min/max range estimation
1156        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1157            let range = max - min;
1158            if range <= 0.0 {
1159                return Some(1.0);
1160            }
1161
1162            let effective_lower = lower.unwrap_or(min).max(min);
1163            let effective_upper = upper.unwrap_or(max).min(max);
1164            let overlap = (effective_upper - effective_lower).max(0.0);
1165            return Some((overlap / range).clamp(0.0, 1.0));
1166        }
1167
1168        None
1169    }
1170
1171    /// Extracts column information and literal value from a comparison.
1172    ///
1173    /// Returns (label, column_name, numeric_value) if the expression is
1174    /// a comparison between a property access and a numeric literal.
1175    fn extract_column_and_value(
1176        &self,
1177        left: &LogicalExpression,
1178        right: &LogicalExpression,
1179    ) -> Option<(String, String, f64)> {
1180        // Try left as property, right as literal
1181        if let Some(result) = self.try_extract_property_literal(left, right) {
1182            return Some(result);
1183        }
1184
1185        // Try right as property, left as literal
1186        self.try_extract_property_literal(right, left)
1187    }
1188
1189    /// Tries to extract property and literal from a specific ordering.
1190    fn try_extract_property_literal(
1191        &self,
1192        property_expr: &LogicalExpression,
1193        literal_expr: &LogicalExpression,
1194    ) -> Option<(String, String, f64)> {
1195        // Extract property access
1196        let (variable, property) = match property_expr {
1197            LogicalExpression::Property { variable, property } => {
1198                (variable.clone(), property.clone())
1199            }
1200            _ => return None,
1201        };
1202
1203        // Extract numeric literal
1204        let value = match literal_expr {
1205            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1206            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1207            _ => return None,
1208        };
1209
1210        // Try to find a label for this variable from table stats
1211        // Use the variable name as a heuristic label lookup
1212        // In practice, the optimizer would track which labels variables are bound to
1213        for label in self.table_stats.keys() {
1214            if let Some(stats) = self.table_stats.get(label)
1215                && stats.columns.contains_key(&property)
1216            {
1217                return Some((label.clone(), property, value));
1218            }
1219        }
1220
1221        // If no stats found but we have the property, return with variable as label
1222        Some((variable, property, value))
1223    }
1224
1225    /// Estimates unary expression selectivity.
1226    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1227        match op {
1228            UnaryOp::Not => 1.0 - self.default_selectivity,
1229            UnaryOp::IsNull => self.selectivity_config.is_null,
1230            UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1231            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
1232        }
1233    }
1234
1235    /// Gets statistics for a column.
1236    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1237        self.table_stats.get(label)?.columns.get(column)
1238    }
1239}
1240
1241impl Default for CardinalityEstimator {
1242    fn default() -> Self {
1243        Self::new()
1244    }
1245}
1246
1247#[cfg(test)]
1248mod tests {
1249    use super::*;
1250    use crate::query::plan::{
1251        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1252        ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1253    };
1254    use grafeo_common::types::Value;
1255
1256    #[test]
1257    fn test_node_scan_with_stats() {
1258        let mut estimator = CardinalityEstimator::new();
1259        estimator.add_table_stats("Person", TableStats::new(5000));
1260
1261        let scan = LogicalOperator::NodeScan(NodeScanOp {
1262            variable: "n".to_string(),
1263            label: Some("Person".to_string()),
1264            input: None,
1265        });
1266
1267        let cardinality = estimator.estimate(&scan);
1268        assert!((cardinality - 5000.0).abs() < 0.001);
1269    }
1270
1271    #[test]
1272    fn test_filter_reduces_cardinality() {
1273        let mut estimator = CardinalityEstimator::new();
1274        estimator.add_table_stats("Person", TableStats::new(1000));
1275
1276        let filter = LogicalOperator::Filter(FilterOp {
1277            predicate: LogicalExpression::Binary {
1278                left: Box::new(LogicalExpression::Property {
1279                    variable: "n".to_string(),
1280                    property: "age".to_string(),
1281                }),
1282                op: BinaryOp::Eq,
1283                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1284            },
1285            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286                variable: "n".to_string(),
1287                label: Some("Person".to_string()),
1288                input: None,
1289            })),
1290            pushdown_hint: None,
1291        });
1292
1293        let cardinality = estimator.estimate(&filter);
1294        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
1295        assert!(cardinality < 1000.0);
1296        assert!(cardinality >= 1.0);
1297    }
1298
1299    #[test]
1300    fn test_join_cardinality() {
1301        let mut estimator = CardinalityEstimator::new();
1302        estimator.add_table_stats("Person", TableStats::new(1000));
1303        estimator.add_table_stats("Company", TableStats::new(100));
1304
1305        let join = LogicalOperator::Join(JoinOp {
1306            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1307                variable: "p".to_string(),
1308                label: Some("Person".to_string()),
1309                input: None,
1310            })),
1311            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1312                variable: "c".to_string(),
1313                label: Some("Company".to_string()),
1314                input: None,
1315            })),
1316            join_type: JoinType::Inner,
1317            conditions: vec![JoinCondition {
1318                left: LogicalExpression::Property {
1319                    variable: "p".to_string(),
1320                    property: "company_id".to_string(),
1321                },
1322                right: LogicalExpression::Property {
1323                    variable: "c".to_string(),
1324                    property: "id".to_string(),
1325                },
1326            }],
1327        });
1328
1329        let cardinality = estimator.estimate(&join);
1330        // Should be less than cross product
1331        assert!(cardinality < 1000.0 * 100.0);
1332    }
1333
1334    #[test]
1335    fn test_limit_caps_cardinality() {
1336        let mut estimator = CardinalityEstimator::new();
1337        estimator.add_table_stats("Person", TableStats::new(1000));
1338
1339        let limit = LogicalOperator::Limit(LimitOp {
1340            count: 10.into(),
1341            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1342                variable: "n".to_string(),
1343                label: Some("Person".to_string()),
1344                input: None,
1345            })),
1346        });
1347
1348        let cardinality = estimator.estimate(&limit);
1349        assert!((cardinality - 10.0).abs() < 0.001);
1350    }
1351
1352    #[test]
1353    fn test_aggregate_reduces_cardinality() {
1354        let mut estimator = CardinalityEstimator::new();
1355        estimator.add_table_stats("Person", TableStats::new(1000));
1356
1357        // Global aggregation
1358        let global_agg = LogicalOperator::Aggregate(AggregateOp {
1359            group_by: vec![],
1360            aggregates: vec![],
1361            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362                variable: "n".to_string(),
1363                label: Some("Person".to_string()),
1364                input: None,
1365            })),
1366            having: None,
1367        });
1368
1369        let cardinality = estimator.estimate(&global_agg);
1370        assert!((cardinality - 1.0).abs() < 0.001);
1371
1372        // Group by aggregation
1373        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1374            group_by: vec![LogicalExpression::Property {
1375                variable: "n".to_string(),
1376                property: "city".to_string(),
1377            }],
1378            aggregates: vec![],
1379            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1380                variable: "n".to_string(),
1381                label: Some("Person".to_string()),
1382                input: None,
1383            })),
1384            having: None,
1385        });
1386
1387        let cardinality = estimator.estimate(&group_agg);
1388        // Should be less than input
1389        assert!(cardinality < 1000.0);
1390    }
1391
1392    #[test]
1393    fn test_node_scan_without_stats() {
1394        let estimator = CardinalityEstimator::new();
1395
1396        let scan = LogicalOperator::NodeScan(NodeScanOp {
1397            variable: "n".to_string(),
1398            label: Some("Unknown".to_string()),
1399            input: None,
1400        });
1401
1402        let cardinality = estimator.estimate(&scan);
1403        // Should return default (1000)
1404        assert!((cardinality - 1000.0).abs() < 0.001);
1405    }
1406
1407    #[test]
1408    fn test_node_scan_no_label() {
1409        let estimator = CardinalityEstimator::new();
1410
1411        let scan = LogicalOperator::NodeScan(NodeScanOp {
1412            variable: "n".to_string(),
1413            label: None,
1414            input: None,
1415        });
1416
1417        let cardinality = estimator.estimate(&scan);
1418        // Should scan all nodes (default)
1419        assert!((cardinality - 1000.0).abs() < 0.001);
1420    }
1421
1422    #[test]
1423    fn test_filter_inequality_selectivity() {
1424        let mut estimator = CardinalityEstimator::new();
1425        estimator.add_table_stats("Person", TableStats::new(1000));
1426
1427        let filter = LogicalOperator::Filter(FilterOp {
1428            predicate: LogicalExpression::Binary {
1429                left: Box::new(LogicalExpression::Property {
1430                    variable: "n".to_string(),
1431                    property: "age".to_string(),
1432                }),
1433                op: BinaryOp::Ne,
1434                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1435            },
1436            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437                variable: "n".to_string(),
1438                label: Some("Person".to_string()),
1439                input: None,
1440            })),
1441            pushdown_hint: None,
1442        });
1443
1444        let cardinality = estimator.estimate(&filter);
1445        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1446        assert!(cardinality > 900.0);
1447    }
1448
1449    #[test]
1450    fn test_filter_range_selectivity() {
1451        let mut estimator = CardinalityEstimator::new();
1452        estimator.add_table_stats("Person", TableStats::new(1000));
1453
1454        let filter = LogicalOperator::Filter(FilterOp {
1455            predicate: LogicalExpression::Binary {
1456                left: Box::new(LogicalExpression::Property {
1457                    variable: "n".to_string(),
1458                    property: "age".to_string(),
1459                }),
1460                op: BinaryOp::Gt,
1461                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1462            },
1463            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1464                variable: "n".to_string(),
1465                label: Some("Person".to_string()),
1466                input: None,
1467            })),
1468            pushdown_hint: None,
1469        });
1470
1471        let cardinality = estimator.estimate(&filter);
1472        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1473        assert!(cardinality < 500.0);
1474        assert!(cardinality > 100.0);
1475    }
1476
1477    #[test]
1478    fn test_filter_and_selectivity() {
1479        let mut estimator = CardinalityEstimator::new();
1480        estimator.add_table_stats("Person", TableStats::new(1000));
1481
1482        // Test AND with two equality predicates
1483        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1484        let filter = LogicalOperator::Filter(FilterOp {
1485            predicate: LogicalExpression::Binary {
1486                left: Box::new(LogicalExpression::Binary {
1487                    left: Box::new(LogicalExpression::Property {
1488                        variable: "n".to_string(),
1489                        property: "city".to_string(),
1490                    }),
1491                    op: BinaryOp::Eq,
1492                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1493                }),
1494                op: BinaryOp::And,
1495                right: Box::new(LogicalExpression::Binary {
1496                    left: Box::new(LogicalExpression::Property {
1497                        variable: "n".to_string(),
1498                        property: "age".to_string(),
1499                    }),
1500                    op: BinaryOp::Eq,
1501                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1502                }),
1503            },
1504            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1505                variable: "n".to_string(),
1506                label: Some("Person".to_string()),
1507                input: None,
1508            })),
1509            pushdown_hint: None,
1510        });
1511
1512        let cardinality = estimator.estimate(&filter);
1513        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1514        // 1000 * 0.0001 = 0.1, min is 1.0
1515        assert!(cardinality < 100.0);
1516        assert!(cardinality >= 1.0);
1517    }
1518
1519    #[test]
1520    fn test_filter_or_selectivity() {
1521        let mut estimator = CardinalityEstimator::new();
1522        estimator.add_table_stats("Person", TableStats::new(1000));
1523
1524        // Test OR with two equality predicates
1525        // Each equality has selectivity 0.01
1526        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1527        let filter = LogicalOperator::Filter(FilterOp {
1528            predicate: LogicalExpression::Binary {
1529                left: Box::new(LogicalExpression::Binary {
1530                    left: Box::new(LogicalExpression::Property {
1531                        variable: "n".to_string(),
1532                        property: "city".to_string(),
1533                    }),
1534                    op: BinaryOp::Eq,
1535                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1536                }),
1537                op: BinaryOp::Or,
1538                right: Box::new(LogicalExpression::Binary {
1539                    left: Box::new(LogicalExpression::Property {
1540                        variable: "n".to_string(),
1541                        property: "city".to_string(),
1542                    }),
1543                    op: BinaryOp::Eq,
1544                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1545                }),
1546            },
1547            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1548                variable: "n".to_string(),
1549                label: Some("Person".to_string()),
1550                input: None,
1551            })),
1552            pushdown_hint: None,
1553        });
1554
1555        let cardinality = estimator.estimate(&filter);
1556        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1557        assert!(cardinality < 100.0);
1558        assert!(cardinality >= 1.0);
1559    }
1560
1561    #[test]
1562    fn test_filter_literal_true() {
1563        let mut estimator = CardinalityEstimator::new();
1564        estimator.add_table_stats("Person", TableStats::new(1000));
1565
1566        let filter = LogicalOperator::Filter(FilterOp {
1567            predicate: LogicalExpression::Literal(Value::Bool(true)),
1568            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1569                variable: "n".to_string(),
1570                label: Some("Person".to_string()),
1571                input: None,
1572            })),
1573            pushdown_hint: None,
1574        });
1575
1576        let cardinality = estimator.estimate(&filter);
1577        // Literal true has selectivity 1.0
1578        assert!((cardinality - 1000.0).abs() < 0.001);
1579    }
1580
1581    #[test]
1582    fn test_filter_literal_false() {
1583        let mut estimator = CardinalityEstimator::new();
1584        estimator.add_table_stats("Person", TableStats::new(1000));
1585
1586        let filter = LogicalOperator::Filter(FilterOp {
1587            predicate: LogicalExpression::Literal(Value::Bool(false)),
1588            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1589                variable: "n".to_string(),
1590                label: Some("Person".to_string()),
1591                input: None,
1592            })),
1593            pushdown_hint: None,
1594        });
1595
1596        let cardinality = estimator.estimate(&filter);
1597        // Literal false has selectivity 0.0, but min is 1.0
1598        assert!((cardinality - 1.0).abs() < 0.001);
1599    }
1600
1601    #[test]
1602    fn test_unary_not_selectivity() {
1603        let mut estimator = CardinalityEstimator::new();
1604        estimator.add_table_stats("Person", TableStats::new(1000));
1605
1606        let filter = LogicalOperator::Filter(FilterOp {
1607            predicate: LogicalExpression::Unary {
1608                op: UnaryOp::Not,
1609                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1610            },
1611            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1612                variable: "n".to_string(),
1613                label: Some("Person".to_string()),
1614                input: None,
1615            })),
1616            pushdown_hint: None,
1617        });
1618
1619        let cardinality = estimator.estimate(&filter);
1620        // NOT inverts selectivity
1621        assert!(cardinality < 1000.0);
1622    }
1623
1624    #[test]
1625    fn test_unary_is_null_selectivity() {
1626        let mut estimator = CardinalityEstimator::new();
1627        estimator.add_table_stats("Person", TableStats::new(1000));
1628
1629        let filter = LogicalOperator::Filter(FilterOp {
1630            predicate: LogicalExpression::Unary {
1631                op: UnaryOp::IsNull,
1632                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1633            },
1634            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1635                variable: "n".to_string(),
1636                label: Some("Person".to_string()),
1637                input: None,
1638            })),
1639            pushdown_hint: None,
1640        });
1641
1642        let cardinality = estimator.estimate(&filter);
1643        // IS NULL has selectivity 0.05
1644        assert!(cardinality < 100.0);
1645    }
1646
1647    #[test]
1648    fn test_expand_cardinality() {
1649        let mut estimator = CardinalityEstimator::new();
1650        estimator.add_table_stats("Person", TableStats::new(100));
1651
1652        let expand = LogicalOperator::Expand(ExpandOp {
1653            from_variable: "a".to_string(),
1654            to_variable: "b".to_string(),
1655            edge_variable: None,
1656            direction: ExpandDirection::Outgoing,
1657            edge_types: vec![],
1658            min_hops: 1,
1659            max_hops: Some(1),
1660            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1661                variable: "a".to_string(),
1662                label: Some("Person".to_string()),
1663                input: None,
1664            })),
1665            path_alias: None,
1666            path_mode: PathMode::Walk,
1667        });
1668
1669        let cardinality = estimator.estimate(&expand);
1670        // Expand multiplies by fanout (10)
1671        assert!(cardinality > 100.0);
1672    }
1673
1674    #[test]
1675    fn test_expand_with_edge_type_filter() {
1676        let mut estimator = CardinalityEstimator::new();
1677        estimator.add_table_stats("Person", TableStats::new(100));
1678
1679        let expand = LogicalOperator::Expand(ExpandOp {
1680            from_variable: "a".to_string(),
1681            to_variable: "b".to_string(),
1682            edge_variable: None,
1683            direction: ExpandDirection::Outgoing,
1684            edge_types: vec!["KNOWS".to_string()],
1685            min_hops: 1,
1686            max_hops: Some(1),
1687            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1688                variable: "a".to_string(),
1689                label: Some("Person".to_string()),
1690                input: None,
1691            })),
1692            path_alias: None,
1693            path_mode: PathMode::Walk,
1694        });
1695
1696        let cardinality = estimator.estimate(&expand);
1697        // With edge type, fanout is reduced by half
1698        assert!(cardinality > 100.0);
1699    }
1700
1701    #[test]
1702    fn test_expand_variable_length() {
1703        let mut estimator = CardinalityEstimator::new();
1704        estimator.add_table_stats("Person", TableStats::new(100));
1705
1706        let expand = LogicalOperator::Expand(ExpandOp {
1707            from_variable: "a".to_string(),
1708            to_variable: "b".to_string(),
1709            edge_variable: None,
1710            direction: ExpandDirection::Outgoing,
1711            edge_types: vec![],
1712            min_hops: 1,
1713            max_hops: Some(3),
1714            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1715                variable: "a".to_string(),
1716                label: Some("Person".to_string()),
1717                input: None,
1718            })),
1719            path_alias: None,
1720            path_mode: PathMode::Walk,
1721        });
1722
1723        let cardinality = estimator.estimate(&expand);
1724        // Variable length path has much higher cardinality
1725        assert!(cardinality > 500.0);
1726    }
1727
1728    #[test]
1729    fn test_join_cross_product() {
1730        let mut estimator = CardinalityEstimator::new();
1731        estimator.add_table_stats("Person", TableStats::new(100));
1732        estimator.add_table_stats("Company", TableStats::new(50));
1733
1734        let join = LogicalOperator::Join(JoinOp {
1735            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1736                variable: "p".to_string(),
1737                label: Some("Person".to_string()),
1738                input: None,
1739            })),
1740            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1741                variable: "c".to_string(),
1742                label: Some("Company".to_string()),
1743                input: None,
1744            })),
1745            join_type: JoinType::Cross,
1746            conditions: vec![],
1747        });
1748
1749        let cardinality = estimator.estimate(&join);
1750        // Cross join = 100 * 50 = 5000
1751        assert!((cardinality - 5000.0).abs() < 0.001);
1752    }
1753
1754    #[test]
1755    fn test_join_left_outer() {
1756        let mut estimator = CardinalityEstimator::new();
1757        estimator.add_table_stats("Person", TableStats::new(1000));
1758        estimator.add_table_stats("Company", TableStats::new(10));
1759
1760        let join = LogicalOperator::Join(JoinOp {
1761            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1762                variable: "p".to_string(),
1763                label: Some("Person".to_string()),
1764                input: None,
1765            })),
1766            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1767                variable: "c".to_string(),
1768                label: Some("Company".to_string()),
1769                input: None,
1770            })),
1771            join_type: JoinType::Left,
1772            conditions: vec![JoinCondition {
1773                left: LogicalExpression::Variable("p".to_string()),
1774                right: LogicalExpression::Variable("c".to_string()),
1775            }],
1776        });
1777
1778        let cardinality = estimator.estimate(&join);
1779        // Left join returns at least all left rows
1780        assert!(cardinality >= 1000.0);
1781    }
1782
1783    #[test]
1784    fn test_join_semi() {
1785        let mut estimator = CardinalityEstimator::new();
1786        estimator.add_table_stats("Person", TableStats::new(1000));
1787        estimator.add_table_stats("Company", TableStats::new(100));
1788
1789        let join = LogicalOperator::Join(JoinOp {
1790            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1791                variable: "p".to_string(),
1792                label: Some("Person".to_string()),
1793                input: None,
1794            })),
1795            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1796                variable: "c".to_string(),
1797                label: Some("Company".to_string()),
1798                input: None,
1799            })),
1800            join_type: JoinType::Semi,
1801            conditions: vec![],
1802        });
1803
1804        let cardinality = estimator.estimate(&join);
1805        // Semi join returns at most left cardinality
1806        assert!(cardinality <= 1000.0);
1807    }
1808
1809    #[test]
1810    fn test_join_anti() {
1811        let mut estimator = CardinalityEstimator::new();
1812        estimator.add_table_stats("Person", TableStats::new(1000));
1813        estimator.add_table_stats("Company", TableStats::new(100));
1814
1815        let join = LogicalOperator::Join(JoinOp {
1816            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1817                variable: "p".to_string(),
1818                label: Some("Person".to_string()),
1819                input: None,
1820            })),
1821            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1822                variable: "c".to_string(),
1823                label: Some("Company".to_string()),
1824                input: None,
1825            })),
1826            join_type: JoinType::Anti,
1827            conditions: vec![],
1828        });
1829
1830        let cardinality = estimator.estimate(&join);
1831        // Anti join returns at most left cardinality
1832        assert!(cardinality <= 1000.0);
1833        assert!(cardinality >= 1.0);
1834    }
1835
1836    #[test]
1837    fn test_project_preserves_cardinality() {
1838        let mut estimator = CardinalityEstimator::new();
1839        estimator.add_table_stats("Person", TableStats::new(1000));
1840
1841        let project = LogicalOperator::Project(ProjectOp {
1842            projections: vec![Projection {
1843                expression: LogicalExpression::Variable("n".to_string()),
1844                alias: None,
1845            }],
1846            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1847                variable: "n".to_string(),
1848                label: Some("Person".to_string()),
1849                input: None,
1850            })),
1851            pass_through_input: false,
1852        });
1853
1854        let cardinality = estimator.estimate(&project);
1855        assert!((cardinality - 1000.0).abs() < 0.001);
1856    }
1857
1858    #[test]
1859    fn test_sort_preserves_cardinality() {
1860        let mut estimator = CardinalityEstimator::new();
1861        estimator.add_table_stats("Person", TableStats::new(1000));
1862
1863        let sort = LogicalOperator::Sort(SortOp {
1864            keys: vec![SortKey {
1865                expression: LogicalExpression::Variable("n".to_string()),
1866                order: SortOrder::Ascending,
1867                nulls: None,
1868            }],
1869            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1870                variable: "n".to_string(),
1871                label: Some("Person".to_string()),
1872                input: None,
1873            })),
1874        });
1875
1876        let cardinality = estimator.estimate(&sort);
1877        assert!((cardinality - 1000.0).abs() < 0.001);
1878    }
1879
1880    #[test]
1881    fn test_distinct_reduces_cardinality() {
1882        let mut estimator = CardinalityEstimator::new();
1883        estimator.add_table_stats("Person", TableStats::new(1000));
1884
1885        let distinct = LogicalOperator::Distinct(DistinctOp {
1886            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1887                variable: "n".to_string(),
1888                label: Some("Person".to_string()),
1889                input: None,
1890            })),
1891            columns: None,
1892        });
1893
1894        let cardinality = estimator.estimate(&distinct);
1895        // Distinct assumes 50% unique
1896        assert!((cardinality - 500.0).abs() < 0.001);
1897    }
1898
1899    #[test]
1900    fn test_skip_reduces_cardinality() {
1901        let mut estimator = CardinalityEstimator::new();
1902        estimator.add_table_stats("Person", TableStats::new(1000));
1903
1904        let skip = LogicalOperator::Skip(SkipOp {
1905            count: 100.into(),
1906            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1907                variable: "n".to_string(),
1908                label: Some("Person".to_string()),
1909                input: None,
1910            })),
1911        });
1912
1913        let cardinality = estimator.estimate(&skip);
1914        assert!((cardinality - 900.0).abs() < 0.001);
1915    }
1916
1917    #[test]
1918    fn test_return_preserves_cardinality() {
1919        let mut estimator = CardinalityEstimator::new();
1920        estimator.add_table_stats("Person", TableStats::new(1000));
1921
1922        let ret = LogicalOperator::Return(ReturnOp {
1923            items: vec![ReturnItem {
1924                expression: LogicalExpression::Variable("n".to_string()),
1925                alias: None,
1926            }],
1927            distinct: false,
1928            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1929                variable: "n".to_string(),
1930                label: Some("Person".to_string()),
1931                input: None,
1932            })),
1933        });
1934
1935        let cardinality = estimator.estimate(&ret);
1936        assert!((cardinality - 1000.0).abs() < 0.001);
1937    }
1938
1939    #[test]
1940    fn test_empty_cardinality() {
1941        let estimator = CardinalityEstimator::new();
1942        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1943        assert!((cardinality).abs() < 0.001);
1944    }
1945
1946    #[test]
1947    fn test_table_stats_with_column() {
1948        let stats = TableStats::new(1000).with_column(
1949            "age",
1950            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1951        );
1952
1953        assert_eq!(stats.row_count, 1000);
1954        let col = stats.columns.get("age").unwrap();
1955        assert_eq!(col.distinct_count, 50);
1956        assert_eq!(col.null_count, 10);
1957        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1958        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1959    }
1960
1961    #[test]
1962    fn test_estimator_default() {
1963        let estimator = CardinalityEstimator::default();
1964        let scan = LogicalOperator::NodeScan(NodeScanOp {
1965            variable: "n".to_string(),
1966            label: None,
1967            input: None,
1968        });
1969        let cardinality = estimator.estimate(&scan);
1970        assert!((cardinality - 1000.0).abs() < 0.001);
1971    }
1972
1973    #[test]
1974    fn test_set_avg_fanout() {
1975        let mut estimator = CardinalityEstimator::new();
1976        estimator.add_table_stats("Person", TableStats::new(100));
1977        estimator.set_avg_fanout(5.0);
1978
1979        let expand = LogicalOperator::Expand(ExpandOp {
1980            from_variable: "a".to_string(),
1981            to_variable: "b".to_string(),
1982            edge_variable: None,
1983            direction: ExpandDirection::Outgoing,
1984            edge_types: vec![],
1985            min_hops: 1,
1986            max_hops: Some(1),
1987            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1988                variable: "a".to_string(),
1989                label: Some("Person".to_string()),
1990                input: None,
1991            })),
1992            path_alias: None,
1993            path_mode: PathMode::Walk,
1994        });
1995
1996        let cardinality = estimator.estimate(&expand);
1997        // With fanout 5: 100 * 5 = 500
1998        assert!((cardinality - 500.0).abs() < 0.001);
1999    }
2000
2001    #[test]
2002    fn test_multiple_group_by_keys_reduce_cardinality() {
2003        // The current implementation uses a simplified model where more group by keys
2004        // results in greater reduction (dividing by 10^num_keys). This is a simplification
2005        // that works for most cases where group by keys are correlated.
2006        let mut estimator = CardinalityEstimator::new();
2007        estimator.add_table_stats("Person", TableStats::new(10000));
2008
2009        let single_group = LogicalOperator::Aggregate(AggregateOp {
2010            group_by: vec![LogicalExpression::Property {
2011                variable: "n".to_string(),
2012                property: "city".to_string(),
2013            }],
2014            aggregates: vec![],
2015            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2016                variable: "n".to_string(),
2017                label: Some("Person".to_string()),
2018                input: None,
2019            })),
2020            having: None,
2021        });
2022
2023        let multi_group = LogicalOperator::Aggregate(AggregateOp {
2024            group_by: vec![
2025                LogicalExpression::Property {
2026                    variable: "n".to_string(),
2027                    property: "city".to_string(),
2028                },
2029                LogicalExpression::Property {
2030                    variable: "n".to_string(),
2031                    property: "country".to_string(),
2032                },
2033            ],
2034            aggregates: vec![],
2035            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2036                variable: "n".to_string(),
2037                label: Some("Person".to_string()),
2038                input: None,
2039            })),
2040            having: None,
2041        });
2042
2043        let single_card = estimator.estimate(&single_group);
2044        let multi_card = estimator.estimate(&multi_group);
2045
2046        // Both should reduce cardinality from input
2047        assert!(single_card < 10000.0);
2048        assert!(multi_card < 10000.0);
2049        // Both should be at least 1
2050        assert!(single_card >= 1.0);
2051        assert!(multi_card >= 1.0);
2052    }
2053
2054    // ============= Histogram Tests =============
2055
2056    #[test]
2057    fn test_histogram_build_uniform() {
2058        // Build histogram from uniformly distributed data
2059        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2060        let histogram = EquiDepthHistogram::build(&values, 10);
2061
2062        assert_eq!(histogram.num_buckets(), 10);
2063        assert_eq!(histogram.total_rows(), 100);
2064
2065        // Each bucket should have approximately 10 rows
2066        for bucket in histogram.buckets() {
2067            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
2068        }
2069    }
2070
2071    #[test]
2072    fn test_histogram_build_skewed() {
2073        // Build histogram from skewed data (many small values, few large)
2074        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
2075        values.extend((0..20).map(|i| 1000.0 + i as f64));
2076        let histogram = EquiDepthHistogram::build(&values, 5);
2077
2078        assert_eq!(histogram.num_buckets(), 5);
2079        assert_eq!(histogram.total_rows(), 100);
2080
2081        // Each bucket should have ~20 rows despite skewed data
2082        for bucket in histogram.buckets() {
2083            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
2084        }
2085    }
2086
2087    #[test]
2088    fn test_histogram_range_selectivity_full() {
2089        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2090        let histogram = EquiDepthHistogram::build(&values, 10);
2091
2092        // Full range should have selectivity ~1.0
2093        let selectivity = histogram.range_selectivity(None, None);
2094        assert!((selectivity - 1.0).abs() < 0.01);
2095    }
2096
2097    #[test]
2098    fn test_histogram_range_selectivity_half() {
2099        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2100        let histogram = EquiDepthHistogram::build(&values, 10);
2101
2102        // Values >= 50 should be ~50% (half the data)
2103        let selectivity = histogram.range_selectivity(Some(50.0), None);
2104        assert!(selectivity > 0.4 && selectivity < 0.6);
2105    }
2106
2107    #[test]
2108    fn test_histogram_range_selectivity_quarter() {
2109        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2110        let histogram = EquiDepthHistogram::build(&values, 10);
2111
2112        // Values < 25 should be ~25%
2113        let selectivity = histogram.range_selectivity(None, Some(25.0));
2114        assert!(selectivity > 0.2 && selectivity < 0.3);
2115    }
2116
2117    #[test]
2118    fn test_histogram_equality_selectivity() {
2119        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2120        let histogram = EquiDepthHistogram::build(&values, 10);
2121
2122        // Equality on 100 distinct values should be ~1%
2123        let selectivity = histogram.equality_selectivity(50.0);
2124        assert!(selectivity > 0.005 && selectivity < 0.02);
2125    }
2126
2127    #[test]
2128    fn test_histogram_empty() {
2129        let histogram = EquiDepthHistogram::build(&[], 10);
2130
2131        assert_eq!(histogram.num_buckets(), 0);
2132        assert_eq!(histogram.total_rows(), 0);
2133
2134        // Default selectivity for empty histogram
2135        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
2136        assert!((selectivity - 0.33).abs() < 0.01);
2137    }
2138
2139    #[test]
2140    fn test_histogram_bucket_overlap() {
2141        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
2142
2143        // Full overlap
2144        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
2145
2146        // Half overlap (lower half)
2147        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
2148
2149        // Half overlap (upper half)
2150        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
2151
2152        // No overlap (below)
2153        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
2154
2155        // No overlap (above)
2156        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
2157    }
2158
2159    #[test]
2160    fn test_column_stats_from_values() {
2161        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2162        let stats = ColumnStats::from_values(values, 4);
2163
2164        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
2165        assert!(stats.min_value.is_some());
2166        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2167        assert!(stats.max_value.is_some());
2168        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2169        assert!(stats.histogram.is_some());
2170    }
2171
2172    #[test]
2173    fn test_column_stats_with_histogram_builder() {
2174        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2175        let histogram = EquiDepthHistogram::build(&values, 10);
2176
2177        let stats = ColumnStats::new(100)
2178            .with_range(0.0, 99.0)
2179            .with_histogram(histogram);
2180
2181        assert!(stats.histogram.is_some());
2182        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2183    }
2184
2185    #[test]
2186    fn test_filter_with_histogram_stats() {
2187        let mut estimator = CardinalityEstimator::new();
2188
2189        // Create stats with histogram for age column
2190        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2191        let histogram = EquiDepthHistogram::build(&age_values, 10);
2192        let age_stats = ColumnStats::new(62)
2193            .with_range(18.0, 79.0)
2194            .with_histogram(histogram);
2195
2196        estimator.add_table_stats(
2197            "Person",
2198            TableStats::new(1000).with_column("age", age_stats),
2199        );
2200
2201        // Filter: age > 50
2202        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
2203        let filter = LogicalOperator::Filter(FilterOp {
2204            predicate: LogicalExpression::Binary {
2205                left: Box::new(LogicalExpression::Property {
2206                    variable: "n".to_string(),
2207                    property: "age".to_string(),
2208                }),
2209                op: BinaryOp::Gt,
2210                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2211            },
2212            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2213                variable: "n".to_string(),
2214                label: Some("Person".to_string()),
2215                input: None,
2216            })),
2217            pushdown_hint: None,
2218        });
2219
2220        let cardinality = estimator.estimate(&filter);
2221
2222        // With histogram, should get more accurate estimate than default 0.33
2223        // Expected: ~47.5% of 1000 = ~475
2224        assert!(cardinality > 300.0 && cardinality < 600.0);
2225    }
2226
2227    #[test]
2228    fn test_filter_equality_with_histogram() {
2229        let mut estimator = CardinalityEstimator::new();
2230
2231        // Create stats with histogram
2232        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2233        let histogram = EquiDepthHistogram::build(&values, 10);
2234        let stats = ColumnStats::new(100)
2235            .with_range(0.0, 99.0)
2236            .with_histogram(histogram);
2237
2238        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2239
2240        // Filter: value = 50
2241        let filter = LogicalOperator::Filter(FilterOp {
2242            predicate: LogicalExpression::Binary {
2243                left: Box::new(LogicalExpression::Property {
2244                    variable: "d".to_string(),
2245                    property: "value".to_string(),
2246                }),
2247                op: BinaryOp::Eq,
2248                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2249            },
2250            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2251                variable: "d".to_string(),
2252                label: Some("Data".to_string()),
2253                input: None,
2254            })),
2255            pushdown_hint: None,
2256        });
2257
2258        let cardinality = estimator.estimate(&filter);
2259
2260        // With 100 distinct values, selectivity should be ~1%
2261        // 1000 * 0.01 = 10
2262        assert!((1.0..50.0).contains(&cardinality));
2263    }
2264
2265    #[test]
2266    fn test_histogram_min_max() {
2267        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2268        let histogram = EquiDepthHistogram::build(&values, 2);
2269
2270        assert_eq!(histogram.min_value(), Some(5.0));
2271        // Max is the upper bound of the last bucket
2272        assert!(histogram.max_value().is_some());
2273    }
2274
2275    // ==================== SelectivityConfig Tests ====================
2276
2277    #[test]
2278    fn test_selectivity_config_defaults() {
2279        let config = SelectivityConfig::new();
2280        assert!((config.default - 0.1).abs() < f64::EPSILON);
2281        assert!((config.equality - 0.01).abs() < f64::EPSILON);
2282        assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2283        assert!((config.range - 0.33).abs() < f64::EPSILON);
2284        assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2285        assert!((config.membership - 0.1).abs() < f64::EPSILON);
2286        assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2287        assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2288        assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2289    }
2290
2291    #[test]
2292    fn test_custom_selectivity_config() {
2293        let config = SelectivityConfig {
2294            equality: 0.05,
2295            range: 0.25,
2296            ..SelectivityConfig::new()
2297        };
2298        let estimator = CardinalityEstimator::with_selectivity_config(config);
2299        assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2300        assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2301    }
2302
2303    #[test]
2304    fn test_custom_selectivity_affects_estimation() {
2305        // Default: equality = 0.01 → 1000 * 0.01 = 10
2306        let mut default_est = CardinalityEstimator::new();
2307        default_est.add_table_stats("Person", TableStats::new(1000));
2308
2309        let filter = LogicalOperator::Filter(FilterOp {
2310            predicate: LogicalExpression::Binary {
2311                left: Box::new(LogicalExpression::Property {
2312                    variable: "n".to_string(),
2313                    property: "name".to_string(),
2314                }),
2315                op: BinaryOp::Eq,
2316                right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2317            },
2318            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2319                variable: "n".to_string(),
2320                label: Some("Person".to_string()),
2321                input: None,
2322            })),
2323            pushdown_hint: None,
2324        });
2325
2326        let default_card = default_est.estimate(&filter);
2327
2328        // Custom: equality = 0.2 → 1000 * 0.2 = 200
2329        let config = SelectivityConfig {
2330            equality: 0.2,
2331            ..SelectivityConfig::new()
2332        };
2333        let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2334        custom_est.add_table_stats("Person", TableStats::new(1000));
2335
2336        let custom_card = custom_est.estimate(&filter);
2337
2338        assert!(custom_card > default_card);
2339        assert!((custom_card - 200.0).abs() < 1.0);
2340    }
2341
2342    #[test]
2343    fn test_custom_range_selectivity() {
2344        let config = SelectivityConfig {
2345            range: 0.5,
2346            ..SelectivityConfig::new()
2347        };
2348        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2349        estimator.add_table_stats("Person", TableStats::new(1000));
2350
2351        let filter = LogicalOperator::Filter(FilterOp {
2352            predicate: LogicalExpression::Binary {
2353                left: Box::new(LogicalExpression::Property {
2354                    variable: "n".to_string(),
2355                    property: "age".to_string(),
2356                }),
2357                op: BinaryOp::Gt,
2358                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2359            },
2360            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2361                variable: "n".to_string(),
2362                label: Some("Person".to_string()),
2363                input: None,
2364            })),
2365            pushdown_hint: None,
2366        });
2367
2368        let cardinality = estimator.estimate(&filter);
2369        // 1000 * 0.5 = 500
2370        assert!((cardinality - 500.0).abs() < 1.0);
2371    }
2372
2373    #[test]
2374    fn test_custom_distinct_fraction() {
2375        let config = SelectivityConfig {
2376            distinct_fraction: 0.8,
2377            ..SelectivityConfig::new()
2378        };
2379        let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2380        estimator.add_table_stats("Person", TableStats::new(1000));
2381
2382        let distinct = LogicalOperator::Distinct(DistinctOp {
2383            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2384                variable: "n".to_string(),
2385                label: Some("Person".to_string()),
2386                input: None,
2387            })),
2388            columns: None,
2389        });
2390
2391        let cardinality = estimator.estimate(&distinct);
2392        // 1000 * 0.8 = 800
2393        assert!((cardinality - 800.0).abs() < 1.0);
2394    }
2395
2396    // ==================== EstimationLog Tests ====================
2397
2398    #[test]
2399    fn test_estimation_log_basic() {
2400        let mut log = EstimationLog::new(10.0);
2401        log.record("NodeScan(Person)", 1000.0, 1200.0);
2402        log.record("Filter(age > 30)", 100.0, 90.0);
2403
2404        assert_eq!(log.entries().len(), 2);
2405        assert!(!log.should_replan()); // 1.2x and 0.9x are within 10x threshold
2406    }
2407
2408    #[test]
2409    fn test_estimation_log_triggers_replan() {
2410        let mut log = EstimationLog::new(10.0);
2411        log.record("NodeScan(Person)", 100.0, 5000.0); // 50x underestimate
2412
2413        assert!(log.should_replan());
2414    }
2415
2416    #[test]
2417    fn test_estimation_log_overestimate_triggers_replan() {
2418        let mut log = EstimationLog::new(5.0);
2419        log.record("Filter", 1000.0, 100.0); // 10x overestimate → ratio = 0.1
2420
2421        assert!(log.should_replan()); // 0.1 < 1/5 = 0.2
2422    }
2423
2424    #[test]
2425    fn test_estimation_entry_error_ratio() {
2426        let entry = EstimationEntry {
2427            operator: "test".into(),
2428            estimated: 100.0,
2429            actual: 200.0,
2430        };
2431        assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2432
2433        let perfect = EstimationEntry {
2434            operator: "test".into(),
2435            estimated: 100.0,
2436            actual: 100.0,
2437        };
2438        assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2439
2440        let zero_est = EstimationEntry {
2441            operator: "test".into(),
2442            estimated: 0.0,
2443            actual: 0.0,
2444        };
2445        assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2446    }
2447
/// `max_error_ratio` reports the largest error among the recorded entries.
#[test]
fn test_estimation_log_max_error_ratio() {
    let mut estimation_log = EstimationLog::new(10.0);
    // Every entry estimates 100 rows; the actuals give 3x, 2x
    // (normalized: 1/0.5 = 2) and 1x respectively.
    for (operator, actual) in [("A", 300.0), ("B", 50.0), ("C", 100.0)] {
        estimation_log.record(operator, 100.0, actual);
    }

    let worst = estimation_log.max_error_ratio();
    assert!((worst - 3.0).abs() < f64::EPSILON);
}
2457
/// `clear` drops every recorded entry, after which no re-plan is requested.
#[test]
fn test_estimation_log_clear() {
    let mut estimation_log = EstimationLog::new(10.0);
    estimation_log.record("A", 100.0, 100.0);
    let recorded_before = estimation_log.entries().len();
    assert_eq!(recorded_before, 1);

    estimation_log.clear();

    let now_empty = estimation_log.entries().is_empty();
    assert!(now_empty);
    assert!(!estimation_log.should_replan());
}
2468
/// A log obtained from `CardinalityEstimator::create_estimation_log` starts
/// with no entries and does not request a re-plan.
#[test]
fn test_create_estimation_log() {
    let fresh_log = CardinalityEstimator::create_estimation_log();

    let has_no_entries = fresh_log.entries().is_empty();
    assert!(has_no_entries);
    assert!(!fresh_log.should_replan());
}
2475
/// A histogram without buckets answers equality lookups with the fixed
/// fallback selectivity of 0.01.
#[test]
fn test_equality_selectivity_empty_histogram() {
    let bucketless = EquiDepthHistogram::new(Vec::new());

    let selectivity = bucketless.equality_selectivity(5.0);
    assert_eq!(selectivity, 0.01);
}
2482
/// A value covered by the histogram gets a selectivity in `(0, 1]`.
#[test]
fn test_equality_selectivity_value_in_bucket() {
    // Samples 1.0 through 10.0, split into two buckets.
    let samples: Vec<f64> = (1..=10_i32).map(f64::from).collect();
    let histogram = EquiDepthHistogram::build(&samples, 2);

    let selectivity = histogram.equality_selectivity(3.0);
    assert!(selectivity > 0.0);
    assert!(selectivity <= 1.0);
}
2491
/// A value far outside every bucket gets the fixed selectivity of 0.001.
#[test]
fn test_equality_selectivity_value_outside_all_buckets() {
    // Samples 1.0 through 10.0, split into two buckets.
    let samples: Vec<f64> = (1..=10_i32).map(f64::from).collect();
    let histogram = EquiDepthHistogram::build(&samples, 2);

    // 9999.0 lies well beyond the sampled range.
    let selectivity = histogram.equality_selectivity(9999.0);
    assert_eq!(selectivity, 0.001);
}
2500
/// A histogram without buckets reports neither a minimum nor a maximum.
#[test]
fn test_histogram_min_max_empty() {
    let bucketless = EquiDepthHistogram::new(Vec::new());

    let lowest = bucketless.min_value();
    assert_eq!(lowest, None);
    let highest = bucketless.max_value();
    assert_eq!(highest, None);
}
2507
/// With a single bucket, min and max are that bucket's lower and upper bounds.
#[test]
fn test_histogram_min_max_single_bucket() {
    let only_bucket = HistogramBucket::new(1.0, 10.0, 5, 5);
    let histogram = EquiDepthHistogram::new(vec![only_bucket]);

    let lowest = histogram.min_value();
    assert_eq!(lowest, Some(1.0));
    let highest = histogram.max_value();
    assert_eq!(highest, Some(10.0));
}
2514
/// A histogram built over several buckets reports the smallest sample as
/// its minimum and a maximum that is at least the largest sample.
#[test]
fn test_histogram_min_max_multi_bucket() {
    let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0];
    let histogram = EquiDepthHistogram::build(&samples, 3);

    let lowest = histogram.min_value().unwrap();
    let highest = histogram.max_value().unwrap();

    let min_offset = (lowest - 1.0).abs();
    assert!(
        min_offset < 1e-9,
        "min should be 1.0, got {min}",
        min = lowest
    );
    assert!(
        highest >= 20.0,
        "max should be >= last value, got {max}",
        max = highest
    );
}
2524
/// An expression with no `AND` in it counts as a single conjunct.
#[test]
fn test_count_and_conjuncts_single_expression() {
    use crate::query::plan::LogicalExpression;

    let lone_literal = LogicalExpression::Literal(Value::Bool(true));

    let conjuncts = count_and_conjuncts(&lone_literal);
    assert_eq!(conjuncts, 1);
}
2531
/// One `AND` joining two literals counts as two conjuncts.
#[test]
fn test_count_and_conjuncts_flat_and() {
    use crate::query::plan::{BinaryOp, LogicalExpression};

    let lhs = LogicalExpression::Literal(Value::Bool(true));
    let rhs = LogicalExpression::Literal(Value::Bool(false));
    let conjunction = LogicalExpression::Binary {
        left: Box::new(lhs),
        op: BinaryOp::And,
        right: Box::new(rhs),
    };

    let conjuncts = count_and_conjuncts(&conjunction);
    assert_eq!(conjuncts, 2);
}
2542
/// `(a AND b) AND (c AND d)` counts as four conjuncts.
#[test]
fn test_count_and_conjuncts_nested_and() {
    use crate::query::plan::{BinaryOp, LogicalExpression};

    // Small builders so the tree shape stays readable.
    let lit = LogicalExpression::Literal;
    let and = |lhs: LogicalExpression, rhs: LogicalExpression| LogicalExpression::Binary {
        left: Box::new(lhs),
        op: BinaryOp::And,
        right: Box::new(rhs),
    };

    let bool_pair = and(lit(Value::Bool(true)), lit(Value::Bool(false)));
    let int_pair = and(lit(Value::Int64(1)), lit(Value::Int64(2)));
    let tree = and(bool_pair, int_pair);

    let conjuncts = count_and_conjuncts(&tree);
    assert_eq!(conjuncts, 4);
}
2563
/// An empty slice contains zero distinct values.
#[test]
fn test_count_distinct_empty() {
    let distinct = count_distinct(&[]);
    assert_eq!(distinct, 0);
}
2568
/// Four different values count as four distinct values.
#[test]
fn test_count_distinct_all_unique() {
    let samples = [1.0, 2.0, 3.0, 4.0];
    let distinct = count_distinct(&samples);
    assert_eq!(distinct, 4);
}
2573
/// Repeated values are counted once each.
#[test]
fn test_count_distinct_with_duplicates() {
    let samples = [1.0, 1.0, 2.0, 2.0, 3.0];
    let distinct = count_distinct(&samples);
    assert_eq!(distinct, 3);
}
2578
/// A slice holding one value repeated has exactly one distinct value.
#[test]
fn test_count_distinct_all_same() {
    let samples = [5.0, 5.0, 5.0];
    let distinct = count_distinct(&samples);
    assert_eq!(distinct, 1);
}
2583
/// A one-element slice has exactly one distinct value.
#[test]
fn test_count_distinct_single_value() {
    let samples = [42.0];
    let distinct = count_distinct(&samples);
    assert_eq!(distinct, 1);
}
2588
2589    // ==================== Vector/Text Scan & AGM Tests ====================
2590
/// A top-k vector scan is estimated at exactly `k` rows, and adding a
/// similarity threshold scales that estimate by 0.7.
#[test]
fn test_estimate_vector_scan_topk_and_threshold() {
    use crate::query::plan::VectorScanOp;

    let estimator = CardinalityEstimator::new();

    // Both scans are identical apart from the similarity threshold.
    let top10_scan = |min_similarity| {
        LogicalOperator::VectorScan(VectorScanOp {
            variable: "n".to_string(),
            index_name: None,
            property: "embedding".to_string(),
            label: None,
            query_vector: LogicalExpression::Variable("q".to_string()),
            k: Some(10),
            metric: None,
            min_similarity,
            max_distance: None,
            input: None,
        })
    };

    // k = 10 with no threshold: the estimate is exactly k.
    let unfiltered_scan = top10_scan(None);
    let unfiltered = estimator.estimate(&unfiltered_scan);
    assert!(unfiltered <= 10.0);
    assert!((unfiltered - 10.0).abs() < 1e-9);

    // With a similarity threshold the estimate becomes 10 * 0.7 = 7.
    let thresholded_scan = top10_scan(Some(0.8));
    let thresholded = estimator.estimate(&thresholded_scan);
    assert!(thresholded < unfiltered);
    assert!(thresholded >= 1.0);
    assert!((thresholded - 7.0).abs() < 1e-9);
}
2633
/// A vector join is estimated at `k` rows per input row, scaled by 0.7 once
/// a similarity threshold is present.
/// (Despite the name this exercises `VectorJoin`, the closest analogue for
/// "text scan top-k" in the existing engine.)
#[test]
fn test_estimate_text_scan_topk_and_threshold() {
    use crate::query::plan::VectorJoinOp;

    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Article", TableStats::new(40));

    let article_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Article".to_string()),
        input: None,
    });

    // Both joins are identical apart from the similarity threshold.
    let top5_join = |min_similarity| {
        LogicalOperator::VectorJoin(VectorJoinOp {
            input: Box::new(article_scan.clone()),
            left_vector_variable: None,
            left_property: None,
            query_vector: LogicalExpression::Variable("q".to_string()),
            right_variable: "m".to_string(),
            right_property: "emb".to_string(),
            right_label: None,
            index_name: None,
            k: 5,
            metric: None,
            min_similarity,
            max_distance: None,
            score_variable: None,
        })
    };

    // k = 5 with no threshold: input * k = 40 * 5 = 200.
    let unfiltered_join = top5_join(None);
    let unfiltered = estimator.estimate(&unfiltered_join);
    assert!((unfiltered - 200.0).abs() < 1e-9);

    // min_similarity applies the 0.7 factor: 200 * 0.7 = 140.
    let thresholded_join = top5_join(Some(0.5));
    let thresholded = estimator.estimate(&thresholded_join);
    assert!(thresholded < unfiltered);
    assert!((thresholded - 140.0).abs() < 1e-9);
}
2688
/// The Person-Works-Company 3-way join is estimated with the AGM bound
/// (`min_card^(n/2)`) instead of the Cartesian product of its inputs.
#[test]
fn test_estimate_multi_way_join_agm_bound() {
    let mut estimator = CardinalityEstimator::new();
    // Works is the smallest input, so the minimum cardinality is 50.
    for (label, rows) in [("Person", 1000), ("Works", 50), ("Company", 200)] {
        estimator.add_table_stats(label, TableStats::new(rows));
    }

    // Labelled node scan with no upstream input.
    let scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };

    let three_way = LogicalOperator::MultiWayJoin(MultiWayJoinOp {
        inputs: vec![
            scan("p", "Person"),
            scan("w", "Works"),
            scan("c", "Company"),
        ],
        conditions: vec![],
        shared_variables: vec!["p".to_string()],
    });

    let card = estimator.estimate(&three_way);

    // AGM: min^(n/2) = 50^(3/2) = 50^1.5 ~= 353.55
    let expected = 50.0_f64.powf(1.5);
    let deviation = (card - expected).abs();
    assert!(deviation < 0.01, "got {card}, expected {expected}");

    // The Cartesian product would be 1000 * 50 * 200 = 10_000_000, far larger.
    let cartesian = 1000.0 * 50.0 * 200.0;
    assert!(card < cartesian);
}
2731
/// A multi-way join with no inputs is estimated at zero rows.
#[test]
fn test_estimate_multi_way_join_empty_inputs() {
    let estimator = CardinalityEstimator::new();
    let inputless_join = LogicalOperator::MultiWayJoin(MultiWayJoinOp {
        inputs: Vec::new(),
        conditions: Vec::new(),
        shared_variables: Vec::new(),
    });

    let estimate = estimator.estimate(&inputless_join);
    assert!(estimate.abs() < f64::EPSILON);
}
2743
/// A range predicate (`age BETWEEN 25 AND 65`) over a column that has
/// min/max statistics but no histogram still produces a selective
/// estimate: strictly between 10 and 1000 rows for a 1000-row table.
#[test]
fn test_range_selectivity_with_histogram_fallback() {
    let mut estimator = CardinalityEstimator::new();
    // The column covers [18, 80] and carries no histogram.
    let age_stats = ColumnStats::new(62).with_range(18.0, 80.0);
    let person_stats = TableStats::new(1000).with_column("age", age_stats);
    estimator.add_table_stats("Person", person_stats);

    // Builds `n.age <op> <bound>`.
    let age_cmp = |op: BinaryOp, bound: i64| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(Value::Int64(bound))),
    };

    // `age >= 25 AND age <= 65`, encoded as an AND of two range predicates.
    let between = LogicalExpression::Binary {
        left: Box::new(age_cmp(BinaryOp::Ge, 25)),
        op: BinaryOp::And,
        right: Box::new(age_cmp(BinaryOp::Le, 65)),
    };

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let filtered_scan = LogicalOperator::Filter(FilterOp {
        predicate: between,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let estimate = estimator.estimate(&filtered_scan);
    // Expected reasoning: each side uses the min/max fallback, so `>= 25`
    // picks (80 - 25)/(80 - 18) ~= 0.89 and `<= 65` picks
    // (65 - 18)/(80 - 18) ~= 0.76, and AND multiplies them under the
    // independence assumption. Only the loose bounds are asserted: well
    // below 1000 and above the single-equality estimate.
    assert!(estimate < 1000.0);
    assert!(estimate > 10.0);
}
2794}