Skip to main content

grafeo_engine/query/optimizer/
cardinality.rs

1//! Cardinality estimation for query optimization.
2//!
3//! Estimates the number of rows produced by each operator in a query plan.
4//!
5//! # Equi-Depth Histograms
6//!
7//! This module provides equi-depth histogram support for accurate selectivity
8//! estimation of range predicates. Unlike equi-width histograms that divide
9//! the value range into equal-sized buckets, equi-depth histograms divide
10//! the data into buckets with approximately equal numbers of rows.
11//!
12//! Benefits:
13//! - Better estimates for skewed data distributions
14//! - More accurate range selectivity than assuming uniform distribution
15//! - Adaptive to actual data characteristics
16
17use crate::query::plan::{
18    AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
19    LogicalExpression, LogicalOperator, NodeScanOp, ProjectOp, SkipOp, SortOp, UnaryOp,
20    VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an equi-depth histogram.
///
/// A bucket describes a value interval together with the number of rows and
/// the number of distinct values that were observed inside that interval.
/// Buckets of an equi-depth histogram all hold roughly the same row count.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower edge of the bucket.
    pub lower_bound: f64,
    /// Upper edge of the bucket (exclusive, except for the last bucket).
    pub upper_bound: f64,
    /// How many rows fall into this bucket.
    pub frequency: u64,
    /// How many distinct values fall into this bucket.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Assembles a bucket from its bounds and counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        HistogramBucket {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Distance between the upper and the lower bound.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Tells whether `value` lies in the half-open interval
    /// `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Computes which share of this bucket (0.0 to 1.0) is covered by a range.
    ///
    /// `None` for `lower` or `upper` means the range is unbounded on that
    /// side. A bucket whose width is not positive is treated as a single
    /// point: it counts fully if the range spans it and not at all otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        // Clip the requested range to the bucket's own bounds.
        let clipped_lower = lower.map_or(self.lower_bound, |l| l.max(self.lower_bound));
        let clipped_upper = upper.map_or(self.upper_bound, |u| u.min(self.upper_bound));

        let span = self.width();
        if span <= 0.0 {
            // Degenerate bucket: all-or-nothing.
            let spans_bucket =
                clipped_lower <= self.lower_bound && clipped_upper >= self.upper_bound;
            return if spans_bucket { 1.0 } else { 0.0 };
        }

        let covered = (clipped_upper - clipped_lower).max(0.0);
        (covered / span).min(1.0)
    }
}
84
85/// An equi-depth histogram for selectivity estimation.
86///
87/// Equi-depth histograms partition data into buckets where each bucket
88/// contains approximately the same number of rows. This provides more
89/// accurate selectivity estimates than assuming uniform distribution,
90/// especially for skewed data.
91///
92/// # Example
93///
94/// ```ignore
95/// use grafeo_engine::query::optimizer::cardinality::EquiDepthHistogram;
96///
97/// // Build a histogram from sorted values
98/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0];
99/// let histogram = EquiDepthHistogram::build(&values, 4);
100///
101/// // Estimate selectivity for age > 25
102/// let selectivity = histogram.range_selectivity(Some(25.0), None);
103/// ```
104#[derive(Debug, Clone)]
105pub struct EquiDepthHistogram {
106    /// The histogram buckets, sorted by lower_bound.
107    buckets: Vec<HistogramBucket>,
108    /// Total number of rows represented by this histogram.
109    total_rows: u64,
110}
111
112impl EquiDepthHistogram {
113    /// Creates a new histogram from pre-built buckets.
114    #[must_use]
115    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
116        let total_rows = buckets.iter().map(|b| b.frequency).sum();
117        Self {
118            buckets,
119            total_rows,
120        }
121    }
122
123    /// Builds an equi-depth histogram from a sorted slice of values.
124    ///
125    /// # Arguments
126    /// * `values` - A sorted slice of numeric values
127    /// * `num_buckets` - The desired number of buckets
128    ///
129    /// # Returns
130    /// An equi-depth histogram with approximately equal row counts per bucket.
131    #[must_use]
132    pub fn build(values: &[f64], num_buckets: usize) -> Self {
133        if values.is_empty() || num_buckets == 0 {
134            return Self {
135                buckets: Vec::new(),
136                total_rows: 0,
137            };
138        }
139
140        let num_buckets = num_buckets.min(values.len());
141        let rows_per_bucket = (values.len() + num_buckets - 1) / num_buckets;
142        let mut buckets = Vec::with_capacity(num_buckets);
143
144        let mut start_idx = 0;
145        while start_idx < values.len() {
146            let end_idx = (start_idx + rows_per_bucket).min(values.len());
147            let lower_bound = values[start_idx];
148            let upper_bound = if end_idx < values.len() {
149                values[end_idx]
150            } else {
151                // For the last bucket, extend slightly beyond the max value
152                values[end_idx - 1] + 1.0
153            };
154
155            // Count distinct values in this bucket
156            let bucket_values = &values[start_idx..end_idx];
157            let distinct_count = count_distinct(bucket_values);
158
159            buckets.push(HistogramBucket::new(
160                lower_bound,
161                upper_bound,
162                (end_idx - start_idx) as u64,
163                distinct_count,
164            ));
165
166            start_idx = end_idx;
167        }
168
169        Self::new(buckets)
170    }
171
172    /// Returns the number of buckets in this histogram.
173    #[must_use]
174    pub fn num_buckets(&self) -> usize {
175        self.buckets.len()
176    }
177
178    /// Returns the total number of rows represented.
179    #[must_use]
180    pub fn total_rows(&self) -> u64 {
181        self.total_rows
182    }
183
184    /// Returns the histogram buckets.
185    #[must_use]
186    pub fn buckets(&self) -> &[HistogramBucket] {
187        &self.buckets
188    }
189
190    /// Estimates selectivity for a range predicate.
191    ///
192    /// # Arguments
193    /// * `lower` - Lower bound (None for unbounded)
194    /// * `upper` - Upper bound (None for unbounded)
195    ///
196    /// # Returns
197    /// Estimated fraction of rows matching the range (0.0 to 1.0).
198    #[must_use]
199    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
200        if self.buckets.is_empty() || self.total_rows == 0 {
201            return 0.33; // Default fallback
202        }
203
204        let mut matching_rows = 0.0;
205
206        for bucket in &self.buckets {
207            // Check if this bucket overlaps with the range
208            let bucket_lower = bucket.lower_bound;
209            let bucket_upper = bucket.upper_bound;
210
211            // Skip buckets entirely outside the range
212            if let Some(l) = lower {
213                if bucket_upper <= l {
214                    continue;
215                }
216            }
217            if let Some(u) = upper {
218                if bucket_lower >= u {
219                    continue;
220                }
221            }
222
223            // Calculate the fraction of this bucket covered by the range
224            let overlap = bucket.overlap_fraction(lower, upper);
225            matching_rows += overlap * bucket.frequency as f64;
226        }
227
228        (matching_rows / self.total_rows as f64).min(1.0).max(0.0)
229    }
230
231    /// Estimates selectivity for an equality predicate.
232    ///
233    /// Uses the distinct count within matching buckets for better accuracy.
234    #[must_use]
235    pub fn equality_selectivity(&self, value: f64) -> f64 {
236        if self.buckets.is_empty() || self.total_rows == 0 {
237            return 0.01; // Default fallback
238        }
239
240        // Find the bucket containing this value
241        for bucket in &self.buckets {
242            if bucket.contains(value) {
243                // Assume uniform distribution within bucket
244                if bucket.distinct_count > 0 {
245                    return (bucket.frequency as f64
246                        / bucket.distinct_count as f64
247                        / self.total_rows as f64)
248                        .min(1.0);
249                }
250            }
251        }
252
253        // Value not in any bucket - very low selectivity
254        0.001
255    }
256
257    /// Gets the minimum value in the histogram.
258    #[must_use]
259    pub fn min_value(&self) -> Option<f64> {
260        self.buckets.first().map(|b| b.lower_bound)
261    }
262
263    /// Gets the maximum value in the histogram.
264    #[must_use]
265    pub fn max_value(&self) -> Option<f64> {
266        self.buckets.last().map(|b| b.upper_bound)
267    }
268}
269
/// Counts the distinct values of a slice that is already sorted.
///
/// A value is counted as new when it differs from the most recently counted
/// value by more than `f64::EPSILON`. An empty slice yields 0.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    match sorted_values.split_first() {
        None => 0,
        Some((&first, rest)) => {
            // State: (distinct values seen so far, last value that was counted).
            let (distinct, _) = rest.iter().fold((1u64, first), |(seen, last), &current| {
                if (current - last).abs() > f64::EPSILON {
                    (seen + 1, current)
                } else {
                    (seen, last)
                }
            });
            distinct
        }
    }
}
288
289/// Statistics for a table/label.
290#[derive(Debug, Clone)]
291pub struct TableStats {
292    /// Total number of rows.
293    pub row_count: u64,
294    /// Column statistics.
295    pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299    /// Creates new table statistics.
300    #[must_use]
301    pub fn new(row_count: u64) -> Self {
302        Self {
303            row_count,
304            columns: HashMap::new(),
305        }
306    }
307
308    /// Adds column statistics.
309    pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310        self.columns.insert(name.to_string(), stats);
311        self
312    }
313}
314
315/// Statistics for a column.
316#[derive(Debug, Clone)]
317pub struct ColumnStats {
318    /// Number of distinct values.
319    pub distinct_count: u64,
320    /// Number of null values.
321    pub null_count: u64,
322    /// Minimum value (if orderable).
323    pub min_value: Option<f64>,
324    /// Maximum value (if orderable).
325    pub max_value: Option<f64>,
326    /// Equi-depth histogram for accurate selectivity estimation.
327    pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331    /// Creates new column statistics.
332    #[must_use]
333    pub fn new(distinct_count: u64) -> Self {
334        Self {
335            distinct_count,
336            null_count: 0,
337            min_value: None,
338            max_value: None,
339            histogram: None,
340        }
341    }
342
343    /// Sets the null count.
344    #[must_use]
345    pub fn with_nulls(mut self, null_count: u64) -> Self {
346        self.null_count = null_count;
347        self
348    }
349
350    /// Sets the min/max range.
351    #[must_use]
352    pub fn with_range(mut self, min: f64, max: f64) -> Self {
353        self.min_value = Some(min);
354        self.max_value = Some(max);
355        self
356    }
357
358    /// Sets the equi-depth histogram for this column.
359    #[must_use]
360    pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361        self.histogram = Some(histogram);
362        self
363    }
364
365    /// Builds column statistics with histogram from raw values.
366    ///
367    /// This is a convenience method that computes all statistics from the data.
368    ///
369    /// # Arguments
370    /// * `values` - The column values (will be sorted internally)
371    /// * `num_buckets` - Number of histogram buckets to create
372    #[must_use]
373    pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374        if values.is_empty() {
375            return Self::new(0);
376        }
377
378        // Sort values for histogram building
379        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381        let min = values.first().copied();
382        let max = values.last().copied();
383        let distinct_count = count_distinct(&values);
384        let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386        Self {
387            distinct_count,
388            null_count: 0,
389            min_value: min,
390            max_value: max,
391            histogram: Some(histogram),
392        }
393    }
394}
395
396/// Cardinality estimator.
397pub struct CardinalityEstimator {
398    /// Statistics for each label/table.
399    table_stats: HashMap<String, TableStats>,
400    /// Default row count for unknown tables.
401    default_row_count: u64,
402    /// Default selectivity for unknown predicates.
403    default_selectivity: f64,
404    /// Average edge fanout (outgoing edges per node).
405    avg_fanout: f64,
406}
407
408impl CardinalityEstimator {
409    /// Creates a new cardinality estimator with default settings.
410    #[must_use]
411    pub fn new() -> Self {
412        Self {
413            table_stats: HashMap::new(),
414            default_row_count: 1000,
415            default_selectivity: 0.1,
416            avg_fanout: 10.0,
417        }
418    }
419
420    /// Adds statistics for a table/label.
421    pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
422        self.table_stats.insert(name.to_string(), stats);
423    }
424
425    /// Sets the average edge fanout.
426    pub fn set_avg_fanout(&mut self, fanout: f64) {
427        self.avg_fanout = fanout;
428    }
429
430    /// Estimates the cardinality of a logical operator.
431    #[must_use]
432    pub fn estimate(&self, op: &LogicalOperator) -> f64 {
433        match op {
434            LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
435            LogicalOperator::Filter(filter) => self.estimate_filter(filter),
436            LogicalOperator::Project(project) => self.estimate_project(project),
437            LogicalOperator::Expand(expand) => self.estimate_expand(expand),
438            LogicalOperator::Join(join) => self.estimate_join(join),
439            LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
440            LogicalOperator::Sort(sort) => self.estimate_sort(sort),
441            LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
442            LogicalOperator::Limit(limit) => self.estimate_limit(limit),
443            LogicalOperator::Skip(skip) => self.estimate_skip(skip),
444            LogicalOperator::Return(ret) => self.estimate(&ret.input),
445            LogicalOperator::Empty => 0.0,
446            LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
447            LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
448            _ => self.default_row_count as f64,
449        }
450    }
451
452    /// Estimates node scan cardinality.
453    fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
454        if let Some(label) = &scan.label {
455            if let Some(stats) = self.table_stats.get(label) {
456                return stats.row_count as f64;
457            }
458        }
459        // No label filter - scan all nodes
460        self.default_row_count as f64
461    }
462
463    /// Estimates filter cardinality.
464    fn estimate_filter(&self, filter: &FilterOp) -> f64 {
465        let input_cardinality = self.estimate(&filter.input);
466        let selectivity = self.estimate_selectivity(&filter.predicate);
467        (input_cardinality * selectivity).max(1.0)
468    }
469
470    /// Estimates projection cardinality (same as input).
471    fn estimate_project(&self, project: &ProjectOp) -> f64 {
472        self.estimate(&project.input)
473    }
474
475    /// Estimates expand cardinality.
476    fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
477        let input_cardinality = self.estimate(&expand.input);
478
479        // Apply fanout based on edge type
480        let fanout = if expand.edge_type.is_some() {
481            // Specific edge type typically has lower fanout
482            self.avg_fanout * 0.5
483        } else {
484            self.avg_fanout
485        };
486
487        // Handle variable-length paths
488        let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
489            let min = expand.min_hops as f64;
490            let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
491            // Geometric series approximation
492            (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
493        } else {
494            fanout
495        };
496
497        (input_cardinality * path_multiplier).max(1.0)
498    }
499
500    /// Estimates join cardinality.
501    fn estimate_join(&self, join: &JoinOp) -> f64 {
502        let left_card = self.estimate(&join.left);
503        let right_card = self.estimate(&join.right);
504
505        match join.join_type {
506            JoinType::Cross => left_card * right_card,
507            JoinType::Inner => {
508                // Assume join selectivity based on conditions
509                let selectivity = if join.conditions.is_empty() {
510                    1.0 // Cross join
511                } else {
512                    // Estimate based on number of conditions
513                    0.1_f64.powi(join.conditions.len() as i32)
514                };
515                (left_card * right_card * selectivity).max(1.0)
516            }
517            JoinType::Left => {
518                // Left join returns at least all left rows
519                let inner_card = self.estimate_join(&JoinOp {
520                    left: join.left.clone(),
521                    right: join.right.clone(),
522                    join_type: JoinType::Inner,
523                    conditions: join.conditions.clone(),
524                });
525                inner_card.max(left_card)
526            }
527            JoinType::Right => {
528                // Right join returns at least all right rows
529                let inner_card = self.estimate_join(&JoinOp {
530                    left: join.left.clone(),
531                    right: join.right.clone(),
532                    join_type: JoinType::Inner,
533                    conditions: join.conditions.clone(),
534                });
535                inner_card.max(right_card)
536            }
537            JoinType::Full => {
538                // Full join returns at least max(left, right)
539                let inner_card = self.estimate_join(&JoinOp {
540                    left: join.left.clone(),
541                    right: join.right.clone(),
542                    join_type: JoinType::Inner,
543                    conditions: join.conditions.clone(),
544                });
545                inner_card.max(left_card.max(right_card))
546            }
547            JoinType::Semi => {
548                // Semi join returns at most left cardinality
549                (left_card * self.default_selectivity).max(1.0)
550            }
551            JoinType::Anti => {
552                // Anti join returns at most left cardinality
553                (left_card * (1.0 - self.default_selectivity)).max(1.0)
554            }
555        }
556    }
557
558    /// Estimates aggregation cardinality.
559    fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
560        let input_cardinality = self.estimate(&agg.input);
561
562        if agg.group_by.is_empty() {
563            // Global aggregation - single row
564            1.0
565        } else {
566            // Group by - estimate distinct groups
567            // Assume each group key reduces cardinality by 10
568            let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
569            (input_cardinality / group_reduction).max(1.0)
570        }
571    }
572
573    /// Estimates sort cardinality (same as input).
574    fn estimate_sort(&self, sort: &SortOp) -> f64 {
575        self.estimate(&sort.input)
576    }
577
578    /// Estimates distinct cardinality.
579    fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
580        let input_cardinality = self.estimate(&distinct.input);
581        // Assume 50% distinct by default
582        (input_cardinality * 0.5).max(1.0)
583    }
584
585    /// Estimates limit cardinality.
586    fn estimate_limit(&self, limit: &LimitOp) -> f64 {
587        let input_cardinality = self.estimate(&limit.input);
588        (limit.count as f64).min(input_cardinality)
589    }
590
591    /// Estimates skip cardinality.
592    fn estimate_skip(&self, skip: &SkipOp) -> f64 {
593        let input_cardinality = self.estimate(&skip.input);
594        (input_cardinality - skip.count as f64).max(0.0)
595    }
596
597    /// Estimates vector scan cardinality.
598    ///
599    /// Vector scan returns at most k results (the k nearest neighbors).
600    /// With similarity/distance filters, it may return fewer.
601    fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
602        let base_k = scan.k as f64;
603
604        // Apply filter selectivity if thresholds are specified
605        let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
606            // Assume 70% of results pass threshold filters
607            0.7
608        } else {
609            1.0
610        };
611
612        (base_k * selectivity).max(1.0)
613    }
614
615    /// Estimates vector join cardinality.
616    ///
617    /// Vector join produces up to k results per input row.
618    fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
619        let input_cardinality = self.estimate(&join.input);
620        let k = join.k as f64;
621
622        // Apply filter selectivity if thresholds are specified
623        let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
624            0.7
625        } else {
626            1.0
627        };
628
629        (input_cardinality * k * selectivity).max(1.0)
630    }
631
632    /// Estimates the selectivity of a predicate (0.0 to 1.0).
633    fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
634        match expr {
635            LogicalExpression::Binary { left, op, right } => {
636                self.estimate_binary_selectivity(left, *op, right)
637            }
638            LogicalExpression::Unary { op, operand } => {
639                self.estimate_unary_selectivity(*op, operand)
640            }
641            LogicalExpression::Literal(value) => {
642                // Boolean literal
643                if let grafeo_common::types::Value::Bool(b) = value {
644                    if *b { 1.0 } else { 0.0 }
645                } else {
646                    self.default_selectivity
647                }
648            }
649            _ => self.default_selectivity,
650        }
651    }
652
653    /// Estimates binary expression selectivity.
654    fn estimate_binary_selectivity(
655        &self,
656        left: &LogicalExpression,
657        op: BinaryOp,
658        right: &LogicalExpression,
659    ) -> f64 {
660        match op {
661            // Equality - try histogram-based estimation
662            BinaryOp::Eq => {
663                if let Some(selectivity) = self.try_equality_selectivity(left, right) {
664                    return selectivity;
665                }
666                0.01
667            }
668            // Inequality is very unselective
669            BinaryOp::Ne => 0.99,
670            // Range predicates - use histogram if available
671            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
672                if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
673                    return selectivity;
674                }
675                0.33
676            }
677            // Logical operators - recursively estimate sub-expressions
678            BinaryOp::And => {
679                let left_sel = self.estimate_selectivity(left);
680                let right_sel = self.estimate_selectivity(right);
681                // AND reduces selectivity (multiply assuming independence)
682                left_sel * right_sel
683            }
684            BinaryOp::Or => {
685                let left_sel = self.estimate_selectivity(left);
686                let right_sel = self.estimate_selectivity(right);
687                // OR: P(A ∪ B) = P(A) + P(B) - P(A ∩ B)
688                // Assuming independence: P(A ∩ B) = P(A) * P(B)
689                (left_sel + right_sel - left_sel * right_sel).min(1.0)
690            }
691            // String operations
692            BinaryOp::StartsWith => 0.1,
693            BinaryOp::EndsWith => 0.1,
694            BinaryOp::Contains => 0.1,
695            BinaryOp::Like => 0.1,
696            // Collection membership
697            BinaryOp::In => 0.1,
698            // Other operations
699            _ => self.default_selectivity,
700        }
701    }
702
703    /// Tries to estimate equality selectivity using histograms.
704    fn try_equality_selectivity(
705        &self,
706        left: &LogicalExpression,
707        right: &LogicalExpression,
708    ) -> Option<f64> {
709        // Extract property access and literal value
710        let (label, column, value) = self.extract_column_and_value(left, right)?;
711
712        // Get column stats with histogram
713        let stats = self.get_column_stats(&label, &column)?;
714
715        // Try histogram-based estimation
716        if let Some(ref histogram) = stats.histogram {
717            return Some(histogram.equality_selectivity(value));
718        }
719
720        // Fall back to distinct count estimation
721        if stats.distinct_count > 0 {
722            return Some(1.0 / stats.distinct_count as f64);
723        }
724
725        None
726    }
727
728    /// Tries to estimate range selectivity using histograms.
729    fn try_range_selectivity(
730        &self,
731        left: &LogicalExpression,
732        op: BinaryOp,
733        right: &LogicalExpression,
734    ) -> Option<f64> {
735        // Extract property access and literal value
736        let (label, column, value) = self.extract_column_and_value(left, right)?;
737
738        // Get column stats
739        let stats = self.get_column_stats(&label, &column)?;
740
741        // Determine the range based on operator
742        let (lower, upper) = match op {
743            BinaryOp::Lt => (None, Some(value)),
744            BinaryOp::Le => (None, Some(value + f64::EPSILON)),
745            BinaryOp::Gt => (Some(value + f64::EPSILON), None),
746            BinaryOp::Ge => (Some(value), None),
747            _ => return None,
748        };
749
750        // Try histogram-based estimation first
751        if let Some(ref histogram) = stats.histogram {
752            return Some(histogram.range_selectivity(lower, upper));
753        }
754
755        // Fall back to min/max range estimation
756        if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
757            let range = max - min;
758            if range <= 0.0 {
759                return Some(1.0);
760            }
761
762            let effective_lower = lower.unwrap_or(min).max(min);
763            let effective_upper = upper.unwrap_or(max).min(max);
764            let overlap = (effective_upper - effective_lower).max(0.0);
765            return Some((overlap / range).min(1.0).max(0.0));
766        }
767
768        None
769    }
770
771    /// Extracts column information and literal value from a comparison.
772    ///
773    /// Returns (label, column_name, numeric_value) if the expression is
774    /// a comparison between a property access and a numeric literal.
775    fn extract_column_and_value(
776        &self,
777        left: &LogicalExpression,
778        right: &LogicalExpression,
779    ) -> Option<(String, String, f64)> {
780        // Try left as property, right as literal
781        if let Some(result) = self.try_extract_property_literal(left, right) {
782            return Some(result);
783        }
784
785        // Try right as property, left as literal
786        self.try_extract_property_literal(right, left)
787    }
788
789    /// Tries to extract property and literal from a specific ordering.
790    fn try_extract_property_literal(
791        &self,
792        property_expr: &LogicalExpression,
793        literal_expr: &LogicalExpression,
794    ) -> Option<(String, String, f64)> {
795        // Extract property access
796        let (variable, property) = match property_expr {
797            LogicalExpression::Property { variable, property } => {
798                (variable.clone(), property.clone())
799            }
800            _ => return None,
801        };
802
803        // Extract numeric literal
804        let value = match literal_expr {
805            LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
806            LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
807            _ => return None,
808        };
809
810        // Try to find a label for this variable from table stats
811        // Use the variable name as a heuristic label lookup
812        // In practice, the optimizer would track which labels variables are bound to
813        for label in self.table_stats.keys() {
814            if let Some(stats) = self.table_stats.get(label) {
815                if stats.columns.contains_key(&property) {
816                    return Some((label.clone(), property, value));
817                }
818            }
819        }
820
821        // If no stats found but we have the property, return with variable as label
822        Some((variable, property, value))
823    }
824
825    /// Estimates unary expression selectivity.
826    fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
827        match op {
828            UnaryOp::Not => 1.0 - self.default_selectivity,
829            UnaryOp::IsNull => 0.05, // Assume 5% nulls
830            UnaryOp::IsNotNull => 0.95,
831            UnaryOp::Neg => 1.0, // Negation doesn't change cardinality
832        }
833    }
834
835    /// Gets statistics for a column.
836    fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
837        self.table_stats.get(label)?.columns.get(column)
838    }
839
840    /// Estimates equality selectivity using column statistics.
841    #[allow(dead_code)]
842    fn estimate_equality_with_stats(&self, label: &str, column: &str) -> f64 {
843        if let Some(stats) = self.get_column_stats(label, column) {
844            if stats.distinct_count > 0 {
845                return 1.0 / stats.distinct_count as f64;
846            }
847        }
848        0.01 // Default for equality
849    }
850
851    /// Estimates range selectivity using column statistics.
852    #[allow(dead_code)]
853    fn estimate_range_with_stats(
854        &self,
855        label: &str,
856        column: &str,
857        lower: Option<f64>,
858        upper: Option<f64>,
859    ) -> f64 {
860        if let Some(stats) = self.get_column_stats(label, column) {
861            if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
862                let range = max - min;
863                if range <= 0.0 {
864                    return 1.0;
865                }
866
867                let effective_lower = lower.unwrap_or(min).max(min);
868                let effective_upper = upper.unwrap_or(max).min(max);
869
870                let overlap = (effective_upper - effective_lower).max(0.0);
871                return (overlap / range).min(1.0).max(0.0);
872            }
873        }
874        0.33 // Default for range
875    }
876}
877
878impl Default for CardinalityEstimator {
879    fn default() -> Self {
880        Self::new()
881    }
882}
883
884#[cfg(test)]
885mod tests {
886    use super::*;
887    use crate::query::plan::{
888        DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, ProjectOp,
889        Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
890    };
891    use grafeo_common::types::Value;
892
893    #[test]
894    fn test_node_scan_with_stats() {
895        let mut estimator = CardinalityEstimator::new();
896        estimator.add_table_stats("Person", TableStats::new(5000));
897
898        let scan = LogicalOperator::NodeScan(NodeScanOp {
899            variable: "n".to_string(),
900            label: Some("Person".to_string()),
901            input: None,
902        });
903
904        let cardinality = estimator.estimate(&scan);
905        assert!((cardinality - 5000.0).abs() < 0.001);
906    }
907
908    #[test]
909    fn test_filter_reduces_cardinality() {
910        let mut estimator = CardinalityEstimator::new();
911        estimator.add_table_stats("Person", TableStats::new(1000));
912
913        let filter = LogicalOperator::Filter(FilterOp {
914            predicate: LogicalExpression::Binary {
915                left: Box::new(LogicalExpression::Property {
916                    variable: "n".to_string(),
917                    property: "age".to_string(),
918                }),
919                op: BinaryOp::Eq,
920                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
921            },
922            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
923                variable: "n".to_string(),
924                label: Some("Person".to_string()),
925                input: None,
926            })),
927        });
928
929        let cardinality = estimator.estimate(&filter);
930        // Equality selectivity is 0.01, so 1000 * 0.01 = 10
931        assert!(cardinality < 1000.0);
932        assert!(cardinality >= 1.0);
933    }
934
935    #[test]
936    fn test_join_cardinality() {
937        let mut estimator = CardinalityEstimator::new();
938        estimator.add_table_stats("Person", TableStats::new(1000));
939        estimator.add_table_stats("Company", TableStats::new(100));
940
941        let join = LogicalOperator::Join(JoinOp {
942            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
943                variable: "p".to_string(),
944                label: Some("Person".to_string()),
945                input: None,
946            })),
947            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
948                variable: "c".to_string(),
949                label: Some("Company".to_string()),
950                input: None,
951            })),
952            join_type: JoinType::Inner,
953            conditions: vec![JoinCondition {
954                left: LogicalExpression::Property {
955                    variable: "p".to_string(),
956                    property: "company_id".to_string(),
957                },
958                right: LogicalExpression::Property {
959                    variable: "c".to_string(),
960                    property: "id".to_string(),
961                },
962            }],
963        });
964
965        let cardinality = estimator.estimate(&join);
966        // Should be less than cross product
967        assert!(cardinality < 1000.0 * 100.0);
968    }
969
970    #[test]
971    fn test_limit_caps_cardinality() {
972        let mut estimator = CardinalityEstimator::new();
973        estimator.add_table_stats("Person", TableStats::new(1000));
974
975        let limit = LogicalOperator::Limit(LimitOp {
976            count: 10,
977            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
978                variable: "n".to_string(),
979                label: Some("Person".to_string()),
980                input: None,
981            })),
982        });
983
984        let cardinality = estimator.estimate(&limit);
985        assert!((cardinality - 10.0).abs() < 0.001);
986    }
987
988    #[test]
989    fn test_aggregate_reduces_cardinality() {
990        let mut estimator = CardinalityEstimator::new();
991        estimator.add_table_stats("Person", TableStats::new(1000));
992
993        // Global aggregation
994        let global_agg = LogicalOperator::Aggregate(AggregateOp {
995            group_by: vec![],
996            aggregates: vec![],
997            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
998                variable: "n".to_string(),
999                label: Some("Person".to_string()),
1000                input: None,
1001            })),
1002            having: None,
1003        });
1004
1005        let cardinality = estimator.estimate(&global_agg);
1006        assert!((cardinality - 1.0).abs() < 0.001);
1007
1008        // Group by aggregation
1009        let group_agg = LogicalOperator::Aggregate(AggregateOp {
1010            group_by: vec![LogicalExpression::Property {
1011                variable: "n".to_string(),
1012                property: "city".to_string(),
1013            }],
1014            aggregates: vec![],
1015            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1016                variable: "n".to_string(),
1017                label: Some("Person".to_string()),
1018                input: None,
1019            })),
1020            having: None,
1021        });
1022
1023        let cardinality = estimator.estimate(&group_agg);
1024        // Should be less than input
1025        assert!(cardinality < 1000.0);
1026    }
1027
1028    #[test]
1029    fn test_node_scan_without_stats() {
1030        let estimator = CardinalityEstimator::new();
1031
1032        let scan = LogicalOperator::NodeScan(NodeScanOp {
1033            variable: "n".to_string(),
1034            label: Some("Unknown".to_string()),
1035            input: None,
1036        });
1037
1038        let cardinality = estimator.estimate(&scan);
1039        // Should return default (1000)
1040        assert!((cardinality - 1000.0).abs() < 0.001);
1041    }
1042
1043    #[test]
1044    fn test_node_scan_no_label() {
1045        let estimator = CardinalityEstimator::new();
1046
1047        let scan = LogicalOperator::NodeScan(NodeScanOp {
1048            variable: "n".to_string(),
1049            label: None,
1050            input: None,
1051        });
1052
1053        let cardinality = estimator.estimate(&scan);
1054        // Should scan all nodes (default)
1055        assert!((cardinality - 1000.0).abs() < 0.001);
1056    }
1057
1058    #[test]
1059    fn test_filter_inequality_selectivity() {
1060        let mut estimator = CardinalityEstimator::new();
1061        estimator.add_table_stats("Person", TableStats::new(1000));
1062
1063        let filter = LogicalOperator::Filter(FilterOp {
1064            predicate: LogicalExpression::Binary {
1065                left: Box::new(LogicalExpression::Property {
1066                    variable: "n".to_string(),
1067                    property: "age".to_string(),
1068                }),
1069                op: BinaryOp::Ne,
1070                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1071            },
1072            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1073                variable: "n".to_string(),
1074                label: Some("Person".to_string()),
1075                input: None,
1076            })),
1077        });
1078
1079        let cardinality = estimator.estimate(&filter);
1080        // Inequality selectivity is 0.99, so 1000 * 0.99 = 990
1081        assert!(cardinality > 900.0);
1082    }
1083
1084    #[test]
1085    fn test_filter_range_selectivity() {
1086        let mut estimator = CardinalityEstimator::new();
1087        estimator.add_table_stats("Person", TableStats::new(1000));
1088
1089        let filter = LogicalOperator::Filter(FilterOp {
1090            predicate: LogicalExpression::Binary {
1091                left: Box::new(LogicalExpression::Property {
1092                    variable: "n".to_string(),
1093                    property: "age".to_string(),
1094                }),
1095                op: BinaryOp::Gt,
1096                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1097            },
1098            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1099                variable: "n".to_string(),
1100                label: Some("Person".to_string()),
1101                input: None,
1102            })),
1103        });
1104
1105        let cardinality = estimator.estimate(&filter);
1106        // Range selectivity is 0.33, so 1000 * 0.33 = 330
1107        assert!(cardinality < 500.0);
1108        assert!(cardinality > 100.0);
1109    }
1110
1111    #[test]
1112    fn test_filter_and_selectivity() {
1113        let mut estimator = CardinalityEstimator::new();
1114        estimator.add_table_stats("Person", TableStats::new(1000));
1115
1116        // Test AND with two equality predicates
1117        // Each equality has selectivity 0.01, so AND gives 0.01 * 0.01 = 0.0001
1118        let filter = LogicalOperator::Filter(FilterOp {
1119            predicate: LogicalExpression::Binary {
1120                left: Box::new(LogicalExpression::Binary {
1121                    left: Box::new(LogicalExpression::Property {
1122                        variable: "n".to_string(),
1123                        property: "city".to_string(),
1124                    }),
1125                    op: BinaryOp::Eq,
1126                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1127                }),
1128                op: BinaryOp::And,
1129                right: Box::new(LogicalExpression::Binary {
1130                    left: Box::new(LogicalExpression::Property {
1131                        variable: "n".to_string(),
1132                        property: "age".to_string(),
1133                    }),
1134                    op: BinaryOp::Eq,
1135                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1136                }),
1137            },
1138            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1139                variable: "n".to_string(),
1140                label: Some("Person".to_string()),
1141                input: None,
1142            })),
1143        });
1144
1145        let cardinality = estimator.estimate(&filter);
1146        // AND reduces selectivity (multiply): 0.01 * 0.01 = 0.0001
1147        // 1000 * 0.0001 = 0.1, min is 1.0
1148        assert!(cardinality < 100.0);
1149        assert!(cardinality >= 1.0);
1150    }
1151
1152    #[test]
1153    fn test_filter_or_selectivity() {
1154        let mut estimator = CardinalityEstimator::new();
1155        estimator.add_table_stats("Person", TableStats::new(1000));
1156
1157        // Test OR with two equality predicates
1158        // Each equality has selectivity 0.01
1159        // OR gives: 0.01 + 0.01 - (0.01 * 0.01) = 0.0199
1160        let filter = LogicalOperator::Filter(FilterOp {
1161            predicate: LogicalExpression::Binary {
1162                left: Box::new(LogicalExpression::Binary {
1163                    left: Box::new(LogicalExpression::Property {
1164                        variable: "n".to_string(),
1165                        property: "city".to_string(),
1166                    }),
1167                    op: BinaryOp::Eq,
1168                    right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1169                }),
1170                op: BinaryOp::Or,
1171                right: Box::new(LogicalExpression::Binary {
1172                    left: Box::new(LogicalExpression::Property {
1173                        variable: "n".to_string(),
1174                        property: "city".to_string(),
1175                    }),
1176                    op: BinaryOp::Eq,
1177                    right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1178                }),
1179            },
1180            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1181                variable: "n".to_string(),
1182                label: Some("Person".to_string()),
1183                input: None,
1184            })),
1185        });
1186
1187        let cardinality = estimator.estimate(&filter);
1188        // OR: 0.01 + 0.01 - 0.0001 ≈ 0.0199, so 1000 * 0.0199 ≈ 19.9
1189        assert!(cardinality < 100.0);
1190        assert!(cardinality >= 1.0);
1191    }
1192
1193    #[test]
1194    fn test_filter_literal_true() {
1195        let mut estimator = CardinalityEstimator::new();
1196        estimator.add_table_stats("Person", TableStats::new(1000));
1197
1198        let filter = LogicalOperator::Filter(FilterOp {
1199            predicate: LogicalExpression::Literal(Value::Bool(true)),
1200            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1201                variable: "n".to_string(),
1202                label: Some("Person".to_string()),
1203                input: None,
1204            })),
1205        });
1206
1207        let cardinality = estimator.estimate(&filter);
1208        // Literal true has selectivity 1.0
1209        assert!((cardinality - 1000.0).abs() < 0.001);
1210    }
1211
1212    #[test]
1213    fn test_filter_literal_false() {
1214        let mut estimator = CardinalityEstimator::new();
1215        estimator.add_table_stats("Person", TableStats::new(1000));
1216
1217        let filter = LogicalOperator::Filter(FilterOp {
1218            predicate: LogicalExpression::Literal(Value::Bool(false)),
1219            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1220                variable: "n".to_string(),
1221                label: Some("Person".to_string()),
1222                input: None,
1223            })),
1224        });
1225
1226        let cardinality = estimator.estimate(&filter);
1227        // Literal false has selectivity 0.0, but min is 1.0
1228        assert!((cardinality - 1.0).abs() < 0.001);
1229    }
1230
1231    #[test]
1232    fn test_unary_not_selectivity() {
1233        let mut estimator = CardinalityEstimator::new();
1234        estimator.add_table_stats("Person", TableStats::new(1000));
1235
1236        let filter = LogicalOperator::Filter(FilterOp {
1237            predicate: LogicalExpression::Unary {
1238                op: UnaryOp::Not,
1239                operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1240            },
1241            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1242                variable: "n".to_string(),
1243                label: Some("Person".to_string()),
1244                input: None,
1245            })),
1246        });
1247
1248        let cardinality = estimator.estimate(&filter);
1249        // NOT inverts selectivity
1250        assert!(cardinality < 1000.0);
1251    }
1252
1253    #[test]
1254    fn test_unary_is_null_selectivity() {
1255        let mut estimator = CardinalityEstimator::new();
1256        estimator.add_table_stats("Person", TableStats::new(1000));
1257
1258        let filter = LogicalOperator::Filter(FilterOp {
1259            predicate: LogicalExpression::Unary {
1260                op: UnaryOp::IsNull,
1261                operand: Box::new(LogicalExpression::Variable("x".to_string())),
1262            },
1263            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1264                variable: "n".to_string(),
1265                label: Some("Person".to_string()),
1266                input: None,
1267            })),
1268        });
1269
1270        let cardinality = estimator.estimate(&filter);
1271        // IS NULL has selectivity 0.05
1272        assert!(cardinality < 100.0);
1273    }
1274
1275    #[test]
1276    fn test_expand_cardinality() {
1277        let mut estimator = CardinalityEstimator::new();
1278        estimator.add_table_stats("Person", TableStats::new(100));
1279
1280        let expand = LogicalOperator::Expand(ExpandOp {
1281            from_variable: "a".to_string(),
1282            to_variable: "b".to_string(),
1283            edge_variable: None,
1284            direction: ExpandDirection::Outgoing,
1285            edge_type: None,
1286            min_hops: 1,
1287            max_hops: Some(1),
1288            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1289                variable: "a".to_string(),
1290                label: Some("Person".to_string()),
1291                input: None,
1292            })),
1293            path_alias: None,
1294        });
1295
1296        let cardinality = estimator.estimate(&expand);
1297        // Expand multiplies by fanout (10)
1298        assert!(cardinality > 100.0);
1299    }
1300
1301    #[test]
1302    fn test_expand_with_edge_type_filter() {
1303        let mut estimator = CardinalityEstimator::new();
1304        estimator.add_table_stats("Person", TableStats::new(100));
1305
1306        let expand = LogicalOperator::Expand(ExpandOp {
1307            from_variable: "a".to_string(),
1308            to_variable: "b".to_string(),
1309            edge_variable: None,
1310            direction: ExpandDirection::Outgoing,
1311            edge_type: Some("KNOWS".to_string()),
1312            min_hops: 1,
1313            max_hops: Some(1),
1314            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1315                variable: "a".to_string(),
1316                label: Some("Person".to_string()),
1317                input: None,
1318            })),
1319            path_alias: None,
1320        });
1321
1322        let cardinality = estimator.estimate(&expand);
1323        // With edge type, fanout is reduced by half
1324        assert!(cardinality > 100.0);
1325    }
1326
1327    #[test]
1328    fn test_expand_variable_length() {
1329        let mut estimator = CardinalityEstimator::new();
1330        estimator.add_table_stats("Person", TableStats::new(100));
1331
1332        let expand = LogicalOperator::Expand(ExpandOp {
1333            from_variable: "a".to_string(),
1334            to_variable: "b".to_string(),
1335            edge_variable: None,
1336            direction: ExpandDirection::Outgoing,
1337            edge_type: None,
1338            min_hops: 1,
1339            max_hops: Some(3),
1340            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1341                variable: "a".to_string(),
1342                label: Some("Person".to_string()),
1343                input: None,
1344            })),
1345            path_alias: None,
1346        });
1347
1348        let cardinality = estimator.estimate(&expand);
1349        // Variable length path has much higher cardinality
1350        assert!(cardinality > 500.0);
1351    }
1352
1353    #[test]
1354    fn test_join_cross_product() {
1355        let mut estimator = CardinalityEstimator::new();
1356        estimator.add_table_stats("Person", TableStats::new(100));
1357        estimator.add_table_stats("Company", TableStats::new(50));
1358
1359        let join = LogicalOperator::Join(JoinOp {
1360            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1361                variable: "p".to_string(),
1362                label: Some("Person".to_string()),
1363                input: None,
1364            })),
1365            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1366                variable: "c".to_string(),
1367                label: Some("Company".to_string()),
1368                input: None,
1369            })),
1370            join_type: JoinType::Cross,
1371            conditions: vec![],
1372        });
1373
1374        let cardinality = estimator.estimate(&join);
1375        // Cross join = 100 * 50 = 5000
1376        assert!((cardinality - 5000.0).abs() < 0.001);
1377    }
1378
1379    #[test]
1380    fn test_join_left_outer() {
1381        let mut estimator = CardinalityEstimator::new();
1382        estimator.add_table_stats("Person", TableStats::new(1000));
1383        estimator.add_table_stats("Company", TableStats::new(10));
1384
1385        let join = LogicalOperator::Join(JoinOp {
1386            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1387                variable: "p".to_string(),
1388                label: Some("Person".to_string()),
1389                input: None,
1390            })),
1391            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1392                variable: "c".to_string(),
1393                label: Some("Company".to_string()),
1394                input: None,
1395            })),
1396            join_type: JoinType::Left,
1397            conditions: vec![JoinCondition {
1398                left: LogicalExpression::Variable("p".to_string()),
1399                right: LogicalExpression::Variable("c".to_string()),
1400            }],
1401        });
1402
1403        let cardinality = estimator.estimate(&join);
1404        // Left join returns at least all left rows
1405        assert!(cardinality >= 1000.0);
1406    }
1407
1408    #[test]
1409    fn test_join_semi() {
1410        let mut estimator = CardinalityEstimator::new();
1411        estimator.add_table_stats("Person", TableStats::new(1000));
1412        estimator.add_table_stats("Company", TableStats::new(100));
1413
1414        let join = LogicalOperator::Join(JoinOp {
1415            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1416                variable: "p".to_string(),
1417                label: Some("Person".to_string()),
1418                input: None,
1419            })),
1420            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1421                variable: "c".to_string(),
1422                label: Some("Company".to_string()),
1423                input: None,
1424            })),
1425            join_type: JoinType::Semi,
1426            conditions: vec![],
1427        });
1428
1429        let cardinality = estimator.estimate(&join);
1430        // Semi join returns at most left cardinality
1431        assert!(cardinality <= 1000.0);
1432    }
1433
1434    #[test]
1435    fn test_join_anti() {
1436        let mut estimator = CardinalityEstimator::new();
1437        estimator.add_table_stats("Person", TableStats::new(1000));
1438        estimator.add_table_stats("Company", TableStats::new(100));
1439
1440        let join = LogicalOperator::Join(JoinOp {
1441            left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1442                variable: "p".to_string(),
1443                label: Some("Person".to_string()),
1444                input: None,
1445            })),
1446            right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1447                variable: "c".to_string(),
1448                label: Some("Company".to_string()),
1449                input: None,
1450            })),
1451            join_type: JoinType::Anti,
1452            conditions: vec![],
1453        });
1454
1455        let cardinality = estimator.estimate(&join);
1456        // Anti join returns at most left cardinality
1457        assert!(cardinality <= 1000.0);
1458        assert!(cardinality >= 1.0);
1459    }
1460
1461    #[test]
1462    fn test_project_preserves_cardinality() {
1463        let mut estimator = CardinalityEstimator::new();
1464        estimator.add_table_stats("Person", TableStats::new(1000));
1465
1466        let project = LogicalOperator::Project(ProjectOp {
1467            projections: vec![Projection {
1468                expression: LogicalExpression::Variable("n".to_string()),
1469                alias: None,
1470            }],
1471            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1472                variable: "n".to_string(),
1473                label: Some("Person".to_string()),
1474                input: None,
1475            })),
1476        });
1477
1478        let cardinality = estimator.estimate(&project);
1479        assert!((cardinality - 1000.0).abs() < 0.001);
1480    }
1481
1482    #[test]
1483    fn test_sort_preserves_cardinality() {
1484        let mut estimator = CardinalityEstimator::new();
1485        estimator.add_table_stats("Person", TableStats::new(1000));
1486
1487        let sort = LogicalOperator::Sort(SortOp {
1488            keys: vec![SortKey {
1489                expression: LogicalExpression::Variable("n".to_string()),
1490                order: SortOrder::Ascending,
1491            }],
1492            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1493                variable: "n".to_string(),
1494                label: Some("Person".to_string()),
1495                input: None,
1496            })),
1497        });
1498
1499        let cardinality = estimator.estimate(&sort);
1500        assert!((cardinality - 1000.0).abs() < 0.001);
1501    }
1502
1503    #[test]
1504    fn test_distinct_reduces_cardinality() {
1505        let mut estimator = CardinalityEstimator::new();
1506        estimator.add_table_stats("Person", TableStats::new(1000));
1507
1508        let distinct = LogicalOperator::Distinct(DistinctOp {
1509            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1510                variable: "n".to_string(),
1511                label: Some("Person".to_string()),
1512                input: None,
1513            })),
1514            columns: None,
1515        });
1516
1517        let cardinality = estimator.estimate(&distinct);
1518        // Distinct assumes 50% unique
1519        assert!((cardinality - 500.0).abs() < 0.001);
1520    }
1521
1522    #[test]
1523    fn test_skip_reduces_cardinality() {
1524        let mut estimator = CardinalityEstimator::new();
1525        estimator.add_table_stats("Person", TableStats::new(1000));
1526
1527        let skip = LogicalOperator::Skip(SkipOp {
1528            count: 100,
1529            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1530                variable: "n".to_string(),
1531                label: Some("Person".to_string()),
1532                input: None,
1533            })),
1534        });
1535
1536        let cardinality = estimator.estimate(&skip);
1537        assert!((cardinality - 900.0).abs() < 0.001);
1538    }
1539
1540    #[test]
1541    fn test_return_preserves_cardinality() {
1542        let mut estimator = CardinalityEstimator::new();
1543        estimator.add_table_stats("Person", TableStats::new(1000));
1544
1545        let ret = LogicalOperator::Return(ReturnOp {
1546            items: vec![ReturnItem {
1547                expression: LogicalExpression::Variable("n".to_string()),
1548                alias: None,
1549            }],
1550            distinct: false,
1551            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1552                variable: "n".to_string(),
1553                label: Some("Person".to_string()),
1554                input: None,
1555            })),
1556        });
1557
1558        let cardinality = estimator.estimate(&ret);
1559        assert!((cardinality - 1000.0).abs() < 0.001);
1560    }
1561
1562    #[test]
1563    fn test_empty_cardinality() {
1564        let estimator = CardinalityEstimator::new();
1565        let cardinality = estimator.estimate(&LogicalOperator::Empty);
1566        assert!((cardinality).abs() < 0.001);
1567    }
1568
1569    #[test]
1570    fn test_table_stats_with_column() {
1571        let stats = TableStats::new(1000).with_column(
1572            "age",
1573            ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1574        );
1575
1576        assert_eq!(stats.row_count, 1000);
1577        let col = stats.columns.get("age").unwrap();
1578        assert_eq!(col.distinct_count, 50);
1579        assert_eq!(col.null_count, 10);
1580        assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1581        assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1582    }
1583
1584    #[test]
1585    fn test_estimator_default() {
1586        let estimator = CardinalityEstimator::default();
1587        let scan = LogicalOperator::NodeScan(NodeScanOp {
1588            variable: "n".to_string(),
1589            label: None,
1590            input: None,
1591        });
1592        let cardinality = estimator.estimate(&scan);
1593        assert!((cardinality - 1000.0).abs() < 0.001);
1594    }
1595
1596    #[test]
1597    fn test_set_avg_fanout() {
1598        let mut estimator = CardinalityEstimator::new();
1599        estimator.add_table_stats("Person", TableStats::new(100));
1600        estimator.set_avg_fanout(5.0);
1601
1602        let expand = LogicalOperator::Expand(ExpandOp {
1603            from_variable: "a".to_string(),
1604            to_variable: "b".to_string(),
1605            edge_variable: None,
1606            direction: ExpandDirection::Outgoing,
1607            edge_type: None,
1608            min_hops: 1,
1609            max_hops: Some(1),
1610            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1611                variable: "a".to_string(),
1612                label: Some("Person".to_string()),
1613                input: None,
1614            })),
1615            path_alias: None,
1616        });
1617
1618        let cardinality = estimator.estimate(&expand);
1619        // With fanout 5: 100 * 5 = 500
1620        assert!((cardinality - 500.0).abs() < 0.001);
1621    }
1622
1623    #[test]
1624    fn test_multiple_group_by_keys_reduce_cardinality() {
1625        // The current implementation uses a simplified model where more group by keys
1626        // results in greater reduction (dividing by 10^num_keys). This is a simplification
1627        // that works for most cases where group by keys are correlated.
1628        let mut estimator = CardinalityEstimator::new();
1629        estimator.add_table_stats("Person", TableStats::new(10000));
1630
1631        let single_group = LogicalOperator::Aggregate(AggregateOp {
1632            group_by: vec![LogicalExpression::Property {
1633                variable: "n".to_string(),
1634                property: "city".to_string(),
1635            }],
1636            aggregates: vec![],
1637            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1638                variable: "n".to_string(),
1639                label: Some("Person".to_string()),
1640                input: None,
1641            })),
1642            having: None,
1643        });
1644
1645        let multi_group = LogicalOperator::Aggregate(AggregateOp {
1646            group_by: vec![
1647                LogicalExpression::Property {
1648                    variable: "n".to_string(),
1649                    property: "city".to_string(),
1650                },
1651                LogicalExpression::Property {
1652                    variable: "n".to_string(),
1653                    property: "country".to_string(),
1654                },
1655            ],
1656            aggregates: vec![],
1657            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1658                variable: "n".to_string(),
1659                label: Some("Person".to_string()),
1660                input: None,
1661            })),
1662            having: None,
1663        });
1664
1665        let single_card = estimator.estimate(&single_group);
1666        let multi_card = estimator.estimate(&multi_group);
1667
1668        // Both should reduce cardinality from input
1669        assert!(single_card < 10000.0);
1670        assert!(multi_card < 10000.0);
1671        // Both should be at least 1
1672        assert!(single_card >= 1.0);
1673        assert!(multi_card >= 1.0);
1674    }
1675
1676    // ============= Histogram Tests =============
1677
1678    #[test]
1679    fn test_histogram_build_uniform() {
1680        // Build histogram from uniformly distributed data
1681        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1682        let histogram = EquiDepthHistogram::build(&values, 10);
1683
1684        assert_eq!(histogram.num_buckets(), 10);
1685        assert_eq!(histogram.total_rows(), 100);
1686
1687        // Each bucket should have approximately 10 rows
1688        for bucket in histogram.buckets() {
1689            assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1690        }
1691    }
1692
1693    #[test]
1694    fn test_histogram_build_skewed() {
1695        // Build histogram from skewed data (many small values, few large)
1696        let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
1697        values.extend((0..20).map(|i| 1000.0 + i as f64));
1698        let histogram = EquiDepthHistogram::build(&values, 5);
1699
1700        assert_eq!(histogram.num_buckets(), 5);
1701        assert_eq!(histogram.total_rows(), 100);
1702
1703        // Each bucket should have ~20 rows despite skewed data
1704        for bucket in histogram.buckets() {
1705            assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
1706        }
1707    }
1708
1709    #[test]
1710    fn test_histogram_range_selectivity_full() {
1711        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1712        let histogram = EquiDepthHistogram::build(&values, 10);
1713
1714        // Full range should have selectivity ~1.0
1715        let selectivity = histogram.range_selectivity(None, None);
1716        assert!((selectivity - 1.0).abs() < 0.01);
1717    }
1718
1719    #[test]
1720    fn test_histogram_range_selectivity_half() {
1721        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1722        let histogram = EquiDepthHistogram::build(&values, 10);
1723
1724        // Values >= 50 should be ~50% (half the data)
1725        let selectivity = histogram.range_selectivity(Some(50.0), None);
1726        assert!(selectivity > 0.4 && selectivity < 0.6);
1727    }
1728
1729    #[test]
1730    fn test_histogram_range_selectivity_quarter() {
1731        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1732        let histogram = EquiDepthHistogram::build(&values, 10);
1733
1734        // Values < 25 should be ~25%
1735        let selectivity = histogram.range_selectivity(None, Some(25.0));
1736        assert!(selectivity > 0.2 && selectivity < 0.3);
1737    }
1738
1739    #[test]
1740    fn test_histogram_equality_selectivity() {
1741        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1742        let histogram = EquiDepthHistogram::build(&values, 10);
1743
1744        // Equality on 100 distinct values should be ~1%
1745        let selectivity = histogram.equality_selectivity(50.0);
1746        assert!(selectivity > 0.005 && selectivity < 0.02);
1747    }
1748
1749    #[test]
1750    fn test_histogram_empty() {
1751        let histogram = EquiDepthHistogram::build(&[], 10);
1752
1753        assert_eq!(histogram.num_buckets(), 0);
1754        assert_eq!(histogram.total_rows(), 0);
1755
1756        // Default selectivity for empty histogram
1757        let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
1758        assert!((selectivity - 0.33).abs() < 0.01);
1759    }
1760
1761    #[test]
1762    fn test_histogram_bucket_overlap() {
1763        let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
1764
1765        // Full overlap
1766        assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
1767
1768        // Half overlap (lower half)
1769        assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
1770
1771        // Half overlap (upper half)
1772        assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
1773
1774        // No overlap (below)
1775        assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
1776
1777        // No overlap (above)
1778        assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
1779    }
1780
1781    #[test]
1782    fn test_column_stats_from_values() {
1783        let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
1784        let stats = ColumnStats::from_values(values, 4);
1785
1786        assert_eq!(stats.distinct_count, 5); // 10, 20, 30, 40, 50
1787        assert!(stats.min_value.is_some());
1788        assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
1789        assert!(stats.max_value.is_some());
1790        assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
1791        assert!(stats.histogram.is_some());
1792    }
1793
1794    #[test]
1795    fn test_column_stats_with_histogram_builder() {
1796        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1797        let histogram = EquiDepthHistogram::build(&values, 10);
1798
1799        let stats = ColumnStats::new(100)
1800            .with_range(0.0, 99.0)
1801            .with_histogram(histogram);
1802
1803        assert!(stats.histogram.is_some());
1804        assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
1805    }
1806
1807    #[test]
1808    fn test_filter_with_histogram_stats() {
1809        let mut estimator = CardinalityEstimator::new();
1810
1811        // Create stats with histogram for age column
1812        let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
1813        let histogram = EquiDepthHistogram::build(&age_values, 10);
1814        let age_stats = ColumnStats::new(62)
1815            .with_range(18.0, 79.0)
1816            .with_histogram(histogram);
1817
1818        estimator.add_table_stats(
1819            "Person",
1820            TableStats::new(1000).with_column("age", age_stats),
1821        );
1822
1823        // Filter: age > 50
1824        // Age range is 18-79, so >50 is about (79-50)/(79-18) = 29/61 ≈ 47.5%
1825        let filter = LogicalOperator::Filter(FilterOp {
1826            predicate: LogicalExpression::Binary {
1827                left: Box::new(LogicalExpression::Property {
1828                    variable: "n".to_string(),
1829                    property: "age".to_string(),
1830                }),
1831                op: BinaryOp::Gt,
1832                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1833            },
1834            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1835                variable: "n".to_string(),
1836                label: Some("Person".to_string()),
1837                input: None,
1838            })),
1839        });
1840
1841        let cardinality = estimator.estimate(&filter);
1842
1843        // With histogram, should get more accurate estimate than default 0.33
1844        // Expected: ~47.5% of 1000 = ~475
1845        assert!(cardinality > 300.0 && cardinality < 600.0);
1846    }
1847
1848    #[test]
1849    fn test_filter_equality_with_histogram() {
1850        let mut estimator = CardinalityEstimator::new();
1851
1852        // Create stats with histogram
1853        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1854        let histogram = EquiDepthHistogram::build(&values, 10);
1855        let stats = ColumnStats::new(100)
1856            .with_range(0.0, 99.0)
1857            .with_histogram(histogram);
1858
1859        estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
1860
1861        // Filter: value = 50
1862        let filter = LogicalOperator::Filter(FilterOp {
1863            predicate: LogicalExpression::Binary {
1864                left: Box::new(LogicalExpression::Property {
1865                    variable: "d".to_string(),
1866                    property: "value".to_string(),
1867                }),
1868                op: BinaryOp::Eq,
1869                right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
1870            },
1871            input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1872                variable: "d".to_string(),
1873                label: Some("Data".to_string()),
1874                input: None,
1875            })),
1876        });
1877
1878        let cardinality = estimator.estimate(&filter);
1879
1880        // With 100 distinct values, selectivity should be ~1%
1881        // 1000 * 0.01 = 10
1882        assert!(cardinality >= 1.0 && cardinality < 50.0);
1883    }
1884
1885    #[test]
1886    fn test_histogram_min_max() {
1887        let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
1888        let histogram = EquiDepthHistogram::build(&values, 2);
1889
1890        assert_eq!(histogram.min_value(), Some(5.0));
1891        // Max is the upper bound of the last bucket
1892        assert!(histogram.max_value().is_some());
1893    }
1894}