Skip to main content

datafusion_physical_plan/operator_statistics/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Pluggable statistics propagation for physical plans.
19//!
20//! This module provides an extensible mechanism for computing statistics
21//! on [`ExecutionPlan`] nodes, following the chain of responsibility pattern
22//! similar to `RelationPlanner` for SQL parsing.
23//!
24//! # Overview
25//!
26//! The default implementation delegates to each operator's built-in
27//! `partition_statistics`. Users can register custom [`StatisticsProvider`]
28//! implementations to:
29//!
30//! 1. Provide statistics for custom [`ExecutionPlan`] implementations
31//! 2. Override default estimation with advanced approaches (e.g., histograms)
32//! 3. Plug in domain-specific knowledge for better cardinality estimation
33//!
34//! # Architecture
35//!
36//! - [`StatisticsProvider`]: Chain element that computes statistics for specific operators
37//! - [`StatisticsRegistry`]: Chains providers, lives in SessionState
38//! - [`ExtendedStatistics`]: Statistics with type-safe custom extensions
39//!
40//! # Built-in Providers
41//!
42//! The following providers are included and can be registered in this order:
43//!
44//! 1. [`FilterStatisticsProvider`] - selectivity-based filter estimation
45//! 2. [`ProjectionStatisticsProvider`] - column mapping through projections
46//! 3. [`PassthroughStatisticsProvider`] - passthrough for cardinality-preserving operators
47//! 4. [`AggregateStatisticsProvider`] - NDV-based GROUP BY cardinality estimation
48//! 5. [`JoinStatisticsProvider`] - NDV-based join output estimation (hash, sort-merge, cross)
49//! 6. [`LimitStatisticsProvider`] - caps output at the fetch limit (local and global)
50//! 7. [`UnionStatisticsProvider`] - sums input row counts
51//! 8. [`DefaultStatisticsProvider`] - fallback to `partition_statistics(None)`
52//!
53//! # Relationship to [#20184](https://github.com/apache/datafusion/issues/20184)
54//!
55//! This module performs its own bottom-up tree walk in [`StatisticsRegistry::compute`],
56//! separate from the walk optimizer rules do via `transform_up`. This means existing
57//! rules that call `partition_statistics` directly bypass the registry.
58//!
59//! [#20184](https://github.com/apache/datafusion/issues/20184) adds a `child_stats`
60//! parameter to `partition_statistics`. Once it lands, the registry can feed enriched
61//! **base** [`Statistics`] into operators' built-in `partition_statistics` calls,
62//! removing redundancy for the base-stats path (row counts, column stats). However,
63//! the separate registry walk is still required for [`ExtendedStatistics`] extension
64//! propagation: `partition_statistics` returns `Arc<Statistics>`, so extensions
65//! (histograms, sketches, etc.) are stripped at that boundary and can only flow
66//! through the registry walk.
67//!
68//! If [`Statistics`] itself were extended to carry a type-erased extension map
69//! (similar to [`ExtendedStatistics`]), the registry walk could be dropped entirely:
70//! extensions would flow naturally through `partition_statistics(child_stats)` and
71//! the registry would become a pure chain-of-responsibility on top of the existing
72//! traversal with no separate walk needed.
73//!
74//! # Example
75//!
76//! ```ignore
77//! use datafusion_physical_plan::operator_statistics::*;
78//!
//! // Create an empty registry (falls back to each operator's built-in `partition_statistics`)
80//! let mut registry = StatisticsRegistry::new();
81//!
82//! // Register custom provider (higher priority)
83//! registry.register(Arc::new(MyHistogramProvider));
84//!
85//! // Compute statistics through the chain
86//! let stats = registry.compute(plan.as_ref())?;
87//! ```
88
89use std::fmt::{self, Debug};
90use std::sync::Arc;
91
92use datafusion_common::extensions::Extensions;
93use datafusion_common::stats::Precision;
94use datafusion_common::{Result, Statistics};
95
96use crate::ExecutionPlan;
97
98// ============================================================================
99// ExtendedStatistics: Statistics with type-safe extensions
100// ============================================================================
101
102/// Statistics with support for custom extensions.
103///
104/// Wraps the standard [`Statistics`] and adds a type-erased extension map
105/// for custom statistics like histograms, sketches, or domain-specific metadata.
106///
107/// # Example
108///
109/// ```ignore
110/// // Define a custom statistics extension
111/// #[derive(Debug, Clone)]
112/// struct HistogramStats {
113///     buckets: Vec<(i64, i64, usize)>, // (min, max, count)
114/// }
115///
116/// // Set extension in a planner
117/// let mut stats = ExtendedStatistics::from(base_stats);
118/// stats.set_extension(HistogramStats { buckets: vec![] });
119///
120/// // Retrieve in a consumer
121/// if let Some(hist) = stats.get_extension::<HistogramStats>() {
122///     // Use histogram for better estimation
123/// }
124/// ```
125#[derive(Debug, Clone, Default)]
126pub struct ExtendedStatistics {
127    /// Standard statistics (num_rows, byte_size, column stats)
128    base: Arc<Statistics>,
129    /// Type-erased extensions for custom statistics
130    extensions: Extensions,
131}
132
133impl ExtendedStatistics {
134    /// Create new ExtendedStatistics wrapping owned statistics.
135    pub fn new(base: Statistics) -> Self {
136        Self {
137            base: Arc::new(base),
138            extensions: Extensions::new(),
139        }
140    }
141
142    /// Create new ExtendedStatistics from an [`Arc<Statistics>`].
143    pub fn new_arc(base: Arc<Statistics>) -> Self {
144        Self {
145            base,
146            extensions: Extensions::new(),
147        }
148    }
149
150    /// Returns a reference to the base [`Statistics`].
151    pub fn base(&self) -> &Statistics {
152        &self.base
153    }
154
155    /// Returns a reference to the underlying [`Arc<Statistics>`].
156    pub fn base_arc(&self) -> &Arc<Statistics> {
157        &self.base
158    }
159
160    /// Get a reference to a custom statistics extension by type.
161    pub fn get_extension<T: 'static + Send + Sync>(&self) -> Option<&T> {
162        self.extensions.get::<T>()
163    }
164
165    /// Set a custom statistics extension.
166    pub fn set_extension<T: 'static + Send + Sync>(&mut self, value: T) {
167        self.extensions.insert(value);
168    }
169
170    /// Check if an extension of the given type exists.
171    pub fn has_extension<T: 'static + Send + Sync>(&self) -> bool {
172        self.extensions.contains::<T>()
173    }
174
175    /// Merge extensions from another ExtendedStatistics (other's extensions take precedence).
176    pub fn merge_extensions(&mut self, other: &ExtendedStatistics) {
177        self.extensions.merge(&other.extensions);
178    }
179}
180
181impl From<Statistics> for ExtendedStatistics {
182    fn from(base: Statistics) -> Self {
183        Self::new(base)
184    }
185}
186
187impl From<Arc<Statistics>> for ExtendedStatistics {
188    fn from(base: Arc<Statistics>) -> Self {
189        Self::new_arc(base)
190    }
191}
192
193impl From<ExtendedStatistics> for Statistics {
194    fn from(extended: ExtendedStatistics) -> Self {
195        Arc::unwrap_or_clone(extended.base)
196    }
197}
198
199// ============================================================================
200// StatisticsProvider trait and registry
201// ============================================================================
202
203/// Result of attempting to compute statistics with a [`StatisticsProvider`].
204#[derive(Debug)]
205pub enum StatisticsResult {
206    /// Statistics were computed by this provider
207    Computed(ExtendedStatistics),
208    /// This provider doesn't handle this operator; delegate to next in chain
209    Delegate,
210}
211
212/// Customize statistics computation for [`ExecutionPlan`] nodes.
213///
214/// Implementations can handle specific operator types or override default
215/// estimation logic. The chain of providers is traversed until one returns
216/// [`StatisticsResult::Computed`].
217///
218/// # Implementing a Custom Provider
219///
220/// ```ignore
221/// #[derive(Debug)]
222/// struct MyStatisticsProvider;
223///
224/// impl StatisticsProvider for MyStatisticsProvider {
225///     fn compute_statistics(
226///         &self,
227///         plan: &dyn ExecutionPlan,
228///         child_stats: &[ExtendedStatistics],
229///     ) -> Result<StatisticsResult> {
230///         if let Some(my_exec) = plan.downcast_ref::<MyCustomExec>() {
231///             // Custom logic for MyCustomExec
232///             Ok(StatisticsResult::Computed(/* ... */))
233///         } else {
234///             // Let next provider handle it
235///             Ok(StatisticsResult::Delegate)
236///         }
237///     }
238/// }
239/// ```
240pub trait StatisticsProvider: Debug + Send + Sync {
241    /// Compute statistics for an [`ExecutionPlan`] node.
242    ///
243    /// # Arguments
244    /// * `plan` - The execution plan node to compute statistics for
245    /// * `child_stats` - Extended statistics already computed for child nodes,
246    ///   in the same order as `plan.children()`. Empty for leaf nodes.
247    ///
248    /// # Returns
249    /// * `StatisticsResult::Computed(stats)` - Short-circuits the chain
250    /// * `StatisticsResult::Delegate` - Passes to next provider in chain
251    fn compute_statistics(
252        &self,
253        plan: &dyn ExecutionPlan,
254        child_stats: &[ExtendedStatistics],
255    ) -> Result<StatisticsResult>;
256}
257
258/// Default statistics provider that delegates to each operator's built-in
259/// `partition_statistics` implementation.
260#[derive(Debug, Default)]
261pub struct DefaultStatisticsProvider;
262
263impl StatisticsProvider for DefaultStatisticsProvider {
264    fn compute_statistics(
265        &self,
266        plan: &dyn ExecutionPlan,
267        _child_stats: &[ExtendedStatistics],
268    ) -> Result<StatisticsResult> {
269        let base = plan.partition_statistics(None)?;
270        Ok(StatisticsResult::Computed(ExtendedStatistics::new_arc(
271            base,
272        )))
273    }
274}
275
276/// Registry that chains [`StatisticsProvider`] implementations.
277///
278/// The registry is a stateless provider chain: it holds no mutable state
279/// and is cheaply `Clone`able / `Send` / `Sync`.
280#[derive(Clone)]
281pub struct StatisticsRegistry {
282    providers: Vec<Arc<dyn StatisticsProvider>>,
283}
284
285impl Debug for StatisticsRegistry {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        write!(f, "StatisticsRegistry({} providers)", self.providers.len())
288    }
289}
290
291impl Default for StatisticsRegistry {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
297impl StatisticsRegistry {
298    /// Create a new empty registry.
299    ///
300    /// With no providers, `compute()` falls back to each plan node's
301    /// built-in `partition_statistics()`. Register providers to enhance
302    /// statistics (e.g., inject NDV, use histograms).
303    pub fn new() -> Self {
304        Self {
305            providers: Vec::new(),
306        }
307    }
308
309    /// Create a registry with the given provider chain.
310    pub fn with_providers(providers: Vec<Arc<dyn StatisticsProvider>>) -> Self {
311        Self { providers }
312    }
313
314    /// Create a registry pre-loaded with the standard built-in providers.
315    ///
316    /// Provider order (first match wins):
317    /// 1. [`FilterStatisticsProvider`]
318    /// 2. [`ProjectionStatisticsProvider`]
319    /// 3. [`PassthroughStatisticsProvider`]
320    /// 4. [`AggregateStatisticsProvider`]
321    /// 5. [`JoinStatisticsProvider`]
322    /// 6. [`LimitStatisticsProvider`]
323    /// 7. [`UnionStatisticsProvider`]
324    /// 8. [`DefaultStatisticsProvider`]
325    pub fn default_with_builtin_providers() -> Self {
326        Self::with_providers(vec![
327            Arc::new(FilterStatisticsProvider),
328            Arc::new(ProjectionStatisticsProvider),
329            Arc::new(PassthroughStatisticsProvider),
330            Arc::new(AggregateStatisticsProvider),
331            Arc::new(JoinStatisticsProvider),
332            Arc::new(LimitStatisticsProvider),
333            Arc::new(UnionStatisticsProvider),
334            Arc::new(DefaultStatisticsProvider),
335        ])
336    }
337
338    /// Register a provider at the front of the chain (higher priority).
339    pub fn register(&mut self, provider: Arc<dyn StatisticsProvider>) {
340        self.providers.insert(0, provider);
341    }
342
343    /// Returns the current provider chain.
344    pub fn providers(&self) -> &[Arc<dyn StatisticsProvider>] {
345        &self.providers
346    }
347
348    /// Compute extended statistics for a plan through the provider chain.
349    ///
350    /// Performs a bottom-up tree walk: child statistics are computed recursively
351    /// and passed to providers, mirroring how `partition_statistics` composes
352    /// operators. Once [#20184](https://github.com/apache/datafusion/issues/20184)
353    /// lands, the registry can feed enriched base stats directly into
354    /// `partition_statistics(child_stats)`, removing the need for a separate walk.
355    ///
356    /// If no providers are registered, falls back to the plan's built-in
357    /// `partition_statistics(None)` with no overhead.
358    pub fn compute(&self, plan: &dyn ExecutionPlan) -> Result<ExtendedStatistics> {
359        // Fast path: no providers registered, skip the walk entirely
360        if self.providers.is_empty() {
361            let base = plan.partition_statistics(None)?;
362            return Ok(ExtendedStatistics::new_arc(base));
363        }
364
365        let children = plan.children();
366
367        // For leaf nodes, try providers with empty child stats.
368        // For non-leaf nodes, recursively compute enhanced child stats first.
369        let child_stats: Vec<ExtendedStatistics> = if children.is_empty() {
370            Vec::new()
371        } else {
372            children
373                .iter()
374                .map(|child| self.compute(child.as_ref()))
375                .collect::<Result<Vec<_>>>()?
376        };
377
378        for provider in &self.providers {
379            match provider.compute_statistics(plan, &child_stats)? {
380                StatisticsResult::Computed(stats) => return Ok(stats),
381                StatisticsResult::Delegate => continue,
382            }
383        }
384        // Fallback: use plan's built-in stats
385        let base = plan.partition_statistics(None)?;
386        Ok(ExtendedStatistics::new_arc(base))
387    }
388
389    /// Compute statistics and return only the base Statistics (no extensions).
390    ///
391    /// Convenience method for callers that don't need extensions.
392    pub fn compute_base(&self, plan: &dyn ExecutionPlan) -> Result<Statistics> {
393        Ok(self.compute(plan)?.base().clone())
394    }
395}
396
397// ============================================================================
398// Statistics Utility Functions
399// ============================================================================
400
/// Expected number of distinct values seen when drawing rows from a domain.
///
/// `domain_size` is the number of distinct values that exist and
/// `num_selected` the number of rows sampled/filtered from them. The estimate
/// is `N * [1 - (1 - 1/N)^n]`, rounded to the nearest integer and kept within
/// `1..=domain_size`.
///
/// Returns 0 when either argument is 0, and `domain_size` when at least as
/// many rows are selected as there are distinct values.
///
/// # References
///
/// Based on Calcite's `RelMdUtil.numDistinctVals()`:
/// <https://github.com/apache/calcite/blob/main/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java>
pub fn num_distinct_vals(domain_size: usize, num_selected: usize) -> usize {
    match (domain_size, num_selected) {
        (0, _) | (_, 0) => return 0,
        (domain, selected) if selected >= domain => return domain,
        _ => {}
    }

    let domain = domain_size as f64;
    let draws = num_selected as f64;

    // Probability that one particular value is never drawn. For a large
    // domain the base (1 - 1/N) sits so close to 1.0 that `powf` loses
    // precision, so the equivalent exp(-n/N) form is used instead.
    // Threshold matches Calcite's RelMdUtil.numDistinctVals().
    let miss_probability = if domain_size > 1000 {
        (-draws / domain).exp()
    } else {
        (1.0 - 1.0 / domain).powf(draws)
    };

    let expected = domain * (1.0 - miss_probability);
    (expected.round() as usize).clamp(1, domain_size)
}
437
/// Estimate how many distinct values remain after a filter with the given
/// selectivity.
///
/// A value that occurs `k` times survives the filter with probability
/// `1 - (1-s)^k`, where `s` is the selectivity. Assuming every value occurs
/// equally often (`rows/ndv` times each):
/// ```text
/// NDV_after ~ NDV_before * [1 - (1 - selectivity)^(rows/NDV)]
/// ```
///
/// Returns 0 when `selectivity <= 0` or when `original_ndv` or
/// `original_rows` is 0, and `original_ndv` when `selectivity >= 1`.
/// Otherwise the rounded estimate is kept within `1..=original_ndv`.
pub fn ndv_after_selectivity(
    original_ndv: usize,
    original_rows: usize,
    selectivity: f64,
) -> usize {
    if original_ndv == 0 || original_rows == 0 || selectivity <= 0.0 {
        return 0;
    }
    if selectivity >= 1.0 {
        return original_ndv;
    }

    let distinct = original_ndv as f64;
    let rows_per_value = original_rows as f64 / distinct;

    // Chance that every row carrying a given value is filtered out.
    let eliminated = (1.0 - selectivity).powf(rows_per_value);
    let expected = distinct * (1.0 - eliminated);

    let rounded = expected.round() as usize;
    rounded.clamp(1, original_ndv)
}
469
470/// Rescale `total_byte_size` proportionally after overriding `num_rows`.
471///
472/// When a provider replaces `num_rows` but keeps the rest of the stats from
473/// `partition_statistics`, the original `total_byte_size` becomes inconsistent.
474/// This function adjusts it by the ratio `new_rows / old_rows`, preserving the
475/// average bytes-per-row from the original estimate.
476fn rescale_byte_size(stats: &mut Statistics, new_num_rows: Precision<usize>) {
477    let old_rows = stats.num_rows;
478    stats.num_rows = new_num_rows;
479    stats.total_byte_size = match (old_rows, new_num_rows, stats.total_byte_size) {
480        (Precision::Exact(old), Precision::Exact(new), Precision::Exact(bytes))
481            if old > 0 =>
482        {
483            Precision::Exact((bytes as f64 * new as f64 / old as f64).round() as usize)
484        }
485        _ => match (
486            old_rows.get_value(),
487            new_num_rows.get_value(),
488            stats.total_byte_size.get_value(),
489        ) {
490            (Some(&old), Some(&new), Some(&bytes)) if old > 0 => Precision::Inexact(
491                (bytes as f64 * new as f64 / old as f64).round() as usize,
492            ),
493            _ => stats.total_byte_size,
494        },
495    };
496}
497
498/// Fetches base statistics from the operator's built-in `partition_statistics`,
499/// overrides `num_rows` with the registry-computed estimate, and rescales
500/// `total_byte_size` proportionally.
501///
502/// Used by providers that compute a better row count but cannot yet propagate
503/// column-level stats (NDV, min/max) through the operator — pending #20184.
504fn computed_with_row_count(
505    plan: &dyn ExecutionPlan,
506    num_rows: Precision<usize>,
507) -> Result<StatisticsResult> {
508    let mut base = Arc::unwrap_or_clone(plan.partition_statistics(None)?);
509    rescale_byte_size(&mut base, num_rows);
510    Ok(StatisticsResult::Computed(ExtendedStatistics::new(base)))
511}
512
513/// Statistics provider for [`FilterExec`](crate::filter::FilterExec) that uses
514/// pre-computed enhanced child statistics from the registry walk.
515///
516/// Unlike the default provider (which calls `partition_statistics` and gets raw
517/// child stats), this provider receives enhanced child stats that may include
518/// NDV overrides injected at the scan level. It applies the same selectivity
519/// estimation logic as `FilterExec::statistics_helper`, then additionally
520/// adjusts each column's `distinct_count` using [`ndv_after_selectivity`] based
521/// on the computed selectivity ratio.
522#[derive(Debug, Default)]
523pub struct FilterStatisticsProvider;
524
525impl StatisticsProvider for FilterStatisticsProvider {
526    fn compute_statistics(
527        &self,
528        plan: &dyn ExecutionPlan,
529        child_stats: &[ExtendedStatistics],
530    ) -> Result<StatisticsResult> {
531        use crate::filter::FilterExec;
532
533        let Some(filter) = plan.downcast_ref::<FilterExec>() else {
534            return Ok(StatisticsResult::Delegate);
535        };
536        if child_stats.is_empty() {
537            return Ok(StatisticsResult::Delegate);
538        }
539
540        let input_stats = (*child_stats[0].base).clone();
541        let input_rows = input_stats.num_rows;
542        let mut stats = FilterExec::statistics_helper(
543            &filter.input().schema(),
544            input_stats,
545            filter.predicate(),
546            filter.default_selectivity(),
547            // TODO: pass filter.expression_analyzer_registry() once #21122 lands
548        )?;
549
550        // Adjust distinct_count for each column using the selectivity ratio
551        // via the probabilistic survival model from
552        // ndv_after_selectivity to account for rows removed by the filter.
553        if let (Some(&orig_rows), Some(&filtered_rows)) =
554            (input_rows.get_value(), stats.num_rows.get_value())
555            && orig_rows > 0
556            && filtered_rows < orig_rows
557        {
558            let selectivity = filtered_rows as f64 / orig_rows as f64;
559            for col_stat in &mut stats.column_statistics {
560                if let Some(&ndv) = col_stat.distinct_count.get_value() {
561                    let adjusted = ndv_after_selectivity(ndv, orig_rows, selectivity);
562                    col_stat.distinct_count = Precision::Inexact(adjusted);
563                }
564            }
565        }
566
567        let stats = stats.project(filter.projection().as_ref());
568        Ok(StatisticsResult::Computed(ExtendedStatistics::new(stats)))
569    }
570}
571
572/// Statistics provider for [`ProjectionExec`](crate::projection::ProjectionExec)
573/// that uses pre-computed enhanced child statistics from the registry walk.
574///
575/// Maps enhanced child column statistics to output columns based on the
576/// projection expressions, preserving NDV and other statistics through
577/// column references.
578#[derive(Debug, Default)]
579pub struct ProjectionStatisticsProvider;
580
581impl StatisticsProvider for ProjectionStatisticsProvider {
582    fn compute_statistics(
583        &self,
584        plan: &dyn ExecutionPlan,
585        child_stats: &[ExtendedStatistics],
586    ) -> Result<StatisticsResult> {
587        use crate::projection::ProjectionExec;
588
589        let Some(proj) = plan.downcast_ref::<ProjectionExec>() else {
590            return Ok(StatisticsResult::Delegate);
591        };
592        if child_stats.is_empty() {
593            return Ok(StatisticsResult::Delegate);
594        }
595
596        let input_stats = (*child_stats[0].base).clone();
597        let output_schema = proj.schema();
598        // TODO: pass proj.expression_analyzer_registry() once #21122 lands,
599        // so expression-level NDV/min/max feeds into projected column stats.
600        let stats = proj
601            .projection_expr()
602            .project_statistics(input_stats, &output_schema)?;
603        Ok(StatisticsResult::Computed(ExtendedStatistics::new(stats)))
604    }
605}
606
607/// Statistics provider for single-input operators with
608/// [`CardinalityEffect::Equal`](crate::execution_plan::CardinalityEffect::Equal).
609///
610/// These operators (Sort, Repartition, CoalescePartitions, etc.) don't
611/// transform statistics, so we pass through the enhanced child stats directly.
612/// This avoids the fallback calling `partition_statistics(None)` which would
613/// trigger a redundant internal recursion with raw (non-enhanced) stats.
614#[derive(Debug, Default)]
615pub struct PassthroughStatisticsProvider;
616
617impl StatisticsProvider for PassthroughStatisticsProvider {
618    fn compute_statistics(
619        &self,
620        plan: &dyn ExecutionPlan,
621        child_stats: &[ExtendedStatistics],
622    ) -> Result<StatisticsResult> {
623        use crate::execution_plan::CardinalityEffect;
624
625        if child_stats.len() != 1
626            || !matches!(plan.cardinality_effect(), CardinalityEffect::Equal)
627        {
628            return Ok(StatisticsResult::Delegate);
629        }
630
631        // Only pass through when the schema is unchanged (same column count).
632        // Operators like WindowAggExec preserve row count but add columns;
633        // passing through child stats would produce wrong column_statistics.
634        let input_cols = child_stats[0].base.column_statistics.len();
635        let output_cols = plan.schema().fields().len();
636        if input_cols != output_cols {
637            return Ok(StatisticsResult::Delegate);
638        }
639
640        Ok(StatisticsResult::Computed(child_stats[0].clone()))
641    }
642}
643
644/// Statistics provider for [`AggregateExec`](crate::aggregates::AggregateExec)
645/// that estimates output cardinality from the NDV of GROUP BY columns.
646///
647/// For each GROUP BY column, looks up `distinct_count` from the enhanced
648/// child statistics. The estimated output rows is the product of all
649/// column NDVs, capped at the input row count. This assumes independence
650/// between columns, so correlated columns (e.g., `city` and `state`) will
651/// produce overestimates.
652///
653/// For GROUPING SETS / CUBE / ROLLUP, delegates to the built-in
654/// `partition_statistics`, which handles per-set NDV estimation correctly.
655///
656/// Delegates when:
657/// - The plan is not an `AggregateExec`
658/// - The aggregate is `Partial` (per-partition, not bounded by global NDV)
659/// - GROUP BY is empty (scalar aggregate)
660/// - Any GROUP BY expression is not a simple column reference
661/// - Any GROUP BY column lacks NDV information
662#[derive(Debug, Default)]
663pub struct AggregateStatisticsProvider;
664
665impl StatisticsProvider for AggregateStatisticsProvider {
666    fn compute_statistics(
667        &self,
668        plan: &dyn ExecutionPlan,
669        child_stats: &[ExtendedStatistics],
670    ) -> Result<StatisticsResult> {
671        use crate::aggregates::AggregateExec;
672        use datafusion_physical_expr::expressions::Column;
673
674        use crate::aggregates::AggregateMode;
675
676        let Some(agg) = plan.downcast_ref::<AggregateExec>() else {
677            return Ok(StatisticsResult::Delegate);
678        };
679
680        // Partial aggregates produce per-partition groups, not bounded by
681        // global NDV; delegate to the built-in estimate for those.
682        if matches!(agg.mode(), AggregateMode::Partial) {
683            return Ok(StatisticsResult::Delegate);
684        }
685
686        if child_stats.is_empty() || agg.group_expr().expr().is_empty() {
687            return Ok(StatisticsResult::Delegate);
688        }
689
690        let input_stats = &child_stats[0].base;
691
692        // Compute NDV product of GROUP BY columns
693        let mut ndv_product: Option<usize> = None;
694        for (expr, _) in agg.group_expr().expr().iter() {
695            let Some(col) = expr.downcast_ref::<Column>() else {
696                return Ok(StatisticsResult::Delegate);
697            };
698            let Some(&ndv) = input_stats
699                .column_statistics
700                .get(col.index())
701                .and_then(|s| s.distinct_count.get_value())
702            else {
703                return Ok(StatisticsResult::Delegate);
704            };
705            if ndv == 0 {
706                return Ok(StatisticsResult::Delegate);
707            }
708            ndv_product = Some(match ndv_product {
709                Some(prev) => prev.saturating_mul(ndv),
710                None => ndv,
711            });
712        }
713
714        let Some(product) = ndv_product else {
715            return Ok(StatisticsResult::Delegate);
716        };
717
718        // For CUBE/ROLLUP/GROUPING SETS (multiple grouping sets), delegate to
719        // the built-in estimate, which handles per-set NDV estimation correctly.
720        if agg.group_expr().groups().len() > 1 {
721            return Ok(StatisticsResult::Delegate);
722        }
723
724        // Cap at input rows
725        let estimate = match input_stats.num_rows.get_value() {
726            Some(&rows) => product.min(rows),
727            None => product,
728        };
729
730        let num_rows = Precision::Inexact(estimate);
731
732        computed_with_row_count(plan, num_rows)
733    }
734}
735
/// Statistics provider for equi-joins (hash join, sort-merge join) and cross joins.
///
/// For equi-joins, the output cardinality is estimated as
/// `left_rows * right_rows / product(max(left_ndv_i, right_ndv_i))`
/// over all join key columns (keys are assumed independent). The estimate
/// falls back to the Cartesian product unless every key has a positive NDV
/// on both the left and the right side.
/// For cross joins, the exact Cartesian product is used.
///
/// The base inner-join estimate is then adjusted for the join type:
/// - Semi joins: capped at the preserved-side row count
/// - Anti joins: preserved-side minus matched rows (clamped to 0)
/// - Left/Right outer: at least as many rows as the preserved side
/// - Full outer: at least `left + right - inner_estimate`
/// - Left mark: exactly `left_rows` (one output row per left row)
///
/// Delegates when:
/// - The plan is not a supported join type
/// - Either input lacks row count information
#[derive(Default, Debug)]
pub struct JoinStatisticsProvider;
756
757impl StatisticsProvider for JoinStatisticsProvider {
758    fn compute_statistics(
759        &self,
760        plan: &dyn ExecutionPlan,
761        child_stats: &[ExtendedStatistics],
762    ) -> Result<StatisticsResult> {
763        use crate::joins::{CrossJoinExec, HashJoinExec, SortMergeJoinExec};
764        use datafusion_common::JoinType;
765        use datafusion_physical_expr::expressions::Column;
766
767        if child_stats.len() < 2 {
768            return Ok(StatisticsResult::Delegate);
769        }
770
771        let left = &child_stats[0].base;
772        let right = &child_stats[1].base;
773
774        let (Some(&left_rows), Some(&right_rows)) =
775            (left.num_rows.get_value(), right.num_rows.get_value())
776        else {
777            return Ok(StatisticsResult::Delegate);
778        };
779
780        use crate::joins::JoinOnRef;
781
782        /// Estimate equi-join output using NDV of join key columns:
783        ///   left_rows * right_rows / product(max(left_ndv_i, right_ndv_i))
784        /// Falls back to Cartesian product if any key lacks NDV on both sides.
785        fn equi_join_estimate(
786            on: JoinOnRef,
787            left: &Statistics,
788            right: &Statistics,
789            left_rows: usize,
790            right_rows: usize,
791        ) -> usize {
792            if on.is_empty() {
793                return left_rows.saturating_mul(right_rows);
794            }
795            let mut ndv_divisor: usize = 1;
796            for (left_key, right_key) in on {
797                let left_ndv = left_key
798                    .downcast_ref::<Column>()
799                    .and_then(|c| left.column_statistics.get(c.index()))
800                    .and_then(|s| s.distinct_count.get_value().copied());
801                let right_ndv = right_key
802                    .downcast_ref::<Column>()
803                    .and_then(|c| right.column_statistics.get(c.index()))
804                    .and_then(|s| s.distinct_count.get_value().copied());
805                match (left_ndv, right_ndv) {
806                    (Some(l), Some(r)) if l > 0 && r > 0 => {
807                        ndv_divisor = ndv_divisor.saturating_mul(l.max(r));
808                    }
809                    _ => return left_rows.saturating_mul(right_rows),
810                }
811            }
812            let max_rows = left_rows.saturating_mul(right_rows);
813            max_rows.checked_div(ndv_divisor).unwrap_or(max_rows)
814        }
815
816        let (inner_estimate, is_exact_cartesian, join_type) = if let Some(hash_join) =
817            plan.downcast_ref::<HashJoinExec>()
818        {
819            let est =
820                equi_join_estimate(hash_join.on(), left, right, left_rows, right_rows);
821            (est, false, *hash_join.join_type())
822        } else if let Some(smj) = plan.downcast_ref::<SortMergeJoinExec>() {
823            let est = equi_join_estimate(smj.on(), left, right, left_rows, right_rows);
824            (est, false, smj.join_type())
825        } else if plan.downcast_ref::<CrossJoinExec>().is_some() {
826            let both_exact = left.num_rows.is_exact().unwrap_or(false)
827                && right.num_rows.is_exact().unwrap_or(false);
828            (
829                left_rows.saturating_mul(right_rows),
830                both_exact,
831                JoinType::Inner,
832            )
833        } else {
834            return Ok(StatisticsResult::Delegate);
835        };
836
837        // Apply join-type-aware cardinality bounds
838        let estimated = match join_type {
839            JoinType::Inner => inner_estimate,
840            JoinType::Left => inner_estimate.max(left_rows),
841            JoinType::Right => inner_estimate.max(right_rows),
842            JoinType::Full => {
843                // At least left + right - matched, but never less than inner
844                let outer_bound = left_rows
845                    .saturating_add(right_rows)
846                    .saturating_sub(inner_estimate);
847                inner_estimate.max(outer_bound)
848            }
849            JoinType::LeftSemi => inner_estimate.min(left_rows),
850            JoinType::RightSemi => inner_estimate.min(right_rows),
851            JoinType::LeftAnti => left_rows.saturating_sub(inner_estimate.min(left_rows)),
852            JoinType::RightAnti => {
853                right_rows.saturating_sub(inner_estimate.min(right_rows))
854            }
855            JoinType::LeftMark => left_rows,
856            JoinType::RightMark => right_rows,
857        };
858
859        // NL join inner with exact inputs is an exact Cartesian product;
860        // NDV-based estimates are inherently inexact.
861        let num_rows = if is_exact_cartesian && join_type == JoinType::Inner {
862            Precision::Exact(estimated)
863        } else {
864            Precision::Inexact(estimated)
865        };
866
867        computed_with_row_count(plan, num_rows)
868    }
869}
870
/// Statistics provider for [`LocalLimitExec`](crate::limit::LocalLimitExec) and
/// [`GlobalLimitExec`](crate::limit::GlobalLimitExec).
///
/// Caps output row count at the limit value, accounting for any leading skip offset
/// in `GlobalLimitExec`. The exact/inexact precision of the input row count is kept;
/// when the input row count is absent, the fetch value (if any) is reported as an
/// inexact row count.
///
/// Delegates when no child statistics are supplied or the plan is not a limit.
#[derive(Debug, Default)]
pub struct LimitStatisticsProvider;
878
879impl StatisticsProvider for LimitStatisticsProvider {
880    fn compute_statistics(
881        &self,
882        plan: &dyn ExecutionPlan,
883        child_stats: &[ExtendedStatistics],
884    ) -> Result<StatisticsResult> {
885        use crate::limit::{GlobalLimitExec, LocalLimitExec};
886
887        if child_stats.is_empty() {
888            return Ok(StatisticsResult::Delegate);
889        }
890
891        let (skip, fetch) = if let Some(limit) = plan.downcast_ref::<LocalLimitExec>() {
892            (0usize, Some(limit.fetch()))
893        } else if let Some(limit) = plan.downcast_ref::<GlobalLimitExec>() {
894            (limit.skip(), limit.fetch())
895        } else {
896            return Ok(StatisticsResult::Delegate);
897        };
898
899        let num_rows = match child_stats[0].base.num_rows {
900            Precision::Exact(rows) => {
901                let available = rows.saturating_sub(skip);
902                Precision::Exact(fetch.map_or(available, |f| available.min(f)))
903            }
904            Precision::Inexact(rows) => {
905                let available = rows.saturating_sub(skip);
906                match fetch {
907                    Some(f) => Precision::Inexact(available.min(f)),
908                    None => Precision::Inexact(available),
909                }
910            }
911            Precision::Absent => match fetch {
912                Some(f) => Precision::Inexact(f),
913                None => Precision::Absent,
914            },
915        };
916
917        computed_with_row_count(plan, num_rows)
918    }
919}
920
/// Statistics provider for [`UnionExec`](crate::union::UnionExec).
///
/// Sums row counts across all inputs (saturating on overflow). The total is
/// exact only when every input row count is exact, inexact when at least one
/// is inexact, and absent as soon as any input has no row count.
#[derive(Debug, Default)]
pub struct UnionStatisticsProvider;
926
927impl StatisticsProvider for UnionStatisticsProvider {
928    fn compute_statistics(
929        &self,
930        plan: &dyn ExecutionPlan,
931        child_stats: &[ExtendedStatistics],
932    ) -> Result<StatisticsResult> {
933        use crate::union::UnionExec;
934
935        if plan.downcast_ref::<UnionExec>().is_none() {
936            return Ok(StatisticsResult::Delegate);
937        }
938
939        let total = child_stats.iter().try_fold(
940            Precision::Exact(0usize),
941            |acc, s| -> Result<Precision<usize>> {
942                Ok(match (acc, s.base.num_rows) {
943                    (Precision::Absent, _) | (_, Precision::Absent) => Precision::Absent,
944                    (Precision::Exact(a), Precision::Exact(b)) => {
945                        Precision::Exact(a.saturating_add(b))
946                    }
947                    (Precision::Inexact(a), Precision::Exact(b))
948                    | (Precision::Exact(a), Precision::Inexact(b))
949                    | (Precision::Inexact(a), Precision::Inexact(b)) => {
950                        Precision::Inexact(a.saturating_add(b))
951                    }
952                })
953            },
954        )?;
955
956        computed_with_row_count(plan, total)
957    }
958}
959
/// Signature of the callback stored by [`ClosureStatisticsProvider`]: it takes
/// the plan node and its children's statistics and returns a
/// [`StatisticsResult`]. `Send + Sync` is required of the stored closure.
type ProviderFn = dyn Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result<StatisticsResult>
    + Send
    + Sync;
963
/// A [`StatisticsProvider`] backed by a user-supplied closure.
///
/// Useful for injecting custom statistics in tests or for cardinality feedback
/// pipelines where real runtime statistics need to override plan estimates.
/// The closure receives the current plan node and its children's
/// [`ExtendedStatistics`], returning a [`StatisticsResult`].
///
/// To distinguish between multiple nodes of the same type (e.g., two
/// `FilterExec` nodes), match on structural properties like the input schema's
/// column names, number of columns, or child row counts.
///
/// # Example
///
/// ```rust,ignore (requires crate-internal imports)
/// let provider = ClosureStatisticsProvider::new(|plan, child_stats| {
///     if plan.downcast_ref::<FilterExec>().is_some() {
///         Ok(StatisticsResult::Computed(ExtendedStatistics::from(Statistics {
///             num_rows: Precision::Inexact(42),
///             ..Statistics::new_unknown(plan.schema().as_ref())
///         })))
///     } else {
///         Ok(StatisticsResult::Delegate)
///     }
/// });
/// ```
pub struct ClosureStatisticsProvider {
    // Boxed callback invoked for every `compute_statistics` call.
    f: Box<ProviderFn>,
}
992
993impl ClosureStatisticsProvider {
994    /// Create a new provider from a closure.
995    pub fn new(
996        f: impl Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result<StatisticsResult>
997        + Send
998        + Sync
999        + 'static,
1000    ) -> Self {
1001        Self { f: Box::new(f) }
1002    }
1003}
1004
1005impl Debug for ClosureStatisticsProvider {
1006    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1007        write!(f, "ClosureStatisticsProvider")
1008    }
1009}
1010
1011impl StatisticsProvider for ClosureStatisticsProvider {
1012    fn compute_statistics(
1013        &self,
1014        plan: &dyn ExecutionPlan,
1015        child_stats: &[ExtendedStatistics],
1016    ) -> Result<StatisticsResult> {
1017        (self.f)(plan, child_stats)
1018    }
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023    use super::*;
1024    use crate::filter::FilterExec;
1025    use crate::projection::ProjectionExec;
1026    use crate::{DisplayAs, DisplayFormatType, PlanProperties};
1027    use arrow::datatypes::{DataType, Field, Schema};
1028    use datafusion_common::stats::Precision;
1029    use datafusion_common::{ColumnStatistics, ScalarValue};
1030    use datafusion_expr::Operator;
1031    use datafusion_physical_expr::PhysicalExpr;
1032    use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, col, lit};
1033    use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
1034    use std::fmt;
1035
1036    use crate::execution_plan::{Boundedness, EmissionType};
1037
1038    fn make_schema() -> Arc<Schema> {
1039        Arc::new(Schema::new(vec![
1040            Field::new("a", DataType::Int32, false),
1041            Field::new("b", DataType::Int32, false),
1042        ]))
1043    }
1044
    /// Leaf plan used as a test input that reports fixed, pre-built statistics.
    #[derive(Debug)]
    struct MockSourceExec {
        // Schema handed out by `schema()`.
        schema: Arc<Schema>,
        // Statistics returned (cloned) from `partition_statistics`.
        stats: Statistics,
        // Plan properties handed out by `properties()`.
        cache: Arc<PlanProperties>,
    }
1051
1052    impl MockSourceExec {
1053        fn new(schema: Arc<Schema>, num_rows: Precision<usize>) -> Self {
1054            let num_cols = schema.fields().len();
1055            Self::with_column_stats(
1056                schema,
1057                num_rows,
1058                vec![ColumnStatistics::new_unknown(); num_cols],
1059            )
1060        }
1061
1062        fn with_column_stats(
1063            schema: Arc<Schema>,
1064            num_rows: Precision<usize>,
1065            column_statistics: Vec<ColumnStatistics>,
1066        ) -> Self {
1067            let eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
1068            let cache = Arc::new(PlanProperties::new(
1069                eq_properties,
1070                Partitioning::UnknownPartitioning(1),
1071                EmissionType::Incremental,
1072                Boundedness::Bounded,
1073            ));
1074            Self {
1075                schema,
1076                stats: Statistics {
1077                    num_rows,
1078                    total_byte_size: Precision::Absent,
1079                    column_statistics,
1080                },
1081                cache,
1082            }
1083        }
1084    }
1085
1086    impl DisplayAs for MockSourceExec {
1087        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1088            write!(f, "MockSourceExec")
1089        }
1090    }
1091
1092    impl ExecutionPlan for MockSourceExec {
1093        fn name(&self) -> &str {
1094            "MockSourceExec"
1095        }
1096
1097        fn schema(&self) -> Arc<Schema> {
1098            Arc::clone(&self.schema)
1099        }
1100
1101        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1102            vec![]
1103        }
1104
1105        fn with_new_children(
1106            self: Arc<Self>,
1107            _children: Vec<Arc<dyn ExecutionPlan>>,
1108        ) -> Result<Arc<dyn ExecutionPlan>> {
1109            Ok(self)
1110        }
1111
1112        fn properties(&self) -> &Arc<PlanProperties> {
1113            &self.cache
1114        }
1115
1116        fn execute(
1117            &self,
1118            _partition: usize,
1119            _context: Arc<datafusion_execution::TaskContext>,
1120        ) -> Result<crate::SendableRecordBatchStream> {
1121            unimplemented!()
1122        }
1123
1124        fn partition_statistics(
1125            &self,
1126            _partition: Option<usize>,
1127        ) -> Result<Arc<Statistics>> {
1128            Ok(Arc::new(self.stats.clone()))
1129        }
1130    }
1131
1132    fn make_source(num_rows: usize) -> Arc<dyn ExecutionPlan> {
1133        Arc::new(MockSourceExec::new(
1134            make_schema(),
1135            Precision::Exact(num_rows),
1136        ))
1137    }
1138
1139    #[test]
1140    fn test_default_provider() -> Result<()> {
1141        let engine = StatisticsRegistry::new();
1142        let source = make_source(1000);
1143
1144        let stats = engine.compute(source.as_ref())?;
1145        assert!(matches!(stats.base.num_rows, Precision::Exact(1000)));
1146        Ok(())
1147    }
1148
1149    #[test]
1150    fn test_custom_chain_configuration() -> Result<()> {
1151        let source = make_source(1000);
1152
1153        // Test with_providers: fully custom chain (no default)
1154        let custom_only =
1155            StatisticsRegistry::with_providers(vec![Arc::new(CustomStatisticsProvider)]);
1156        // CustomStatisticsProvider only handles CustomExec, delegates for others
1157        // With no default provider, filter returns fallback statistics
1158        let filter: Arc<dyn ExecutionPlan> =
1159            Arc::new(FilterExec::try_new(lit(true), Arc::clone(&source))?);
1160        let stats = custom_only.compute(filter.as_ref())?;
1161        // Falls back to plan.statistics() since no provider handles it
1162        assert!(stats.base.num_rows.get_value().is_some());
1163
1164        // Test with_providers: custom provider + built-in fallback
1165        let with_override =
1166            StatisticsRegistry::with_providers(vec![Arc::new(OverrideFilterProvider {
1167                fixed_selectivity: 0.25,
1168            })
1169                as Arc<dyn StatisticsProvider>]);
1170        // OverrideFilterProvider handles filters, built-in fallback handles the rest
1171        let stats = with_override.compute(filter.as_ref())?;
1172        assert!(matches!(stats.base.num_rows, Precision::Inexact(250)));
1173
1174        // Verify chain inspection
1175        assert_eq!(with_override.providers().len(), 1);
1176
1177        Ok(())
1178    }
1179
    /// Single-input test operator; its schema and properties are taken
    /// directly from `input`.
    #[derive(Debug)]
    struct CustomExec {
        // The only child of this node.
        input: Arc<dyn ExecutionPlan>,
    }
1184
1185    impl DisplayAs for CustomExec {
1186        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1187            write!(f, "CustomExec")
1188        }
1189    }
1190
1191    impl ExecutionPlan for CustomExec {
1192        fn name(&self) -> &str {
1193            "CustomExec"
1194        }
1195
1196        fn schema(&self) -> Arc<Schema> {
1197            self.input.schema()
1198        }
1199
1200        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1201            vec![&self.input]
1202        }
1203
1204        fn with_new_children(
1205            self: Arc<Self>,
1206            children: Vec<Arc<dyn ExecutionPlan>>,
1207        ) -> Result<Arc<dyn ExecutionPlan>> {
1208            Ok(Arc::new(CustomExec {
1209                input: Arc::clone(&children[0]),
1210            }))
1211        }
1212
1213        fn properties(&self) -> &Arc<PlanProperties> {
1214            self.input.properties()
1215        }
1216
1217        fn execute(
1218            &self,
1219            _partition: usize,
1220            _context: Arc<datafusion_execution::TaskContext>,
1221        ) -> Result<crate::SendableRecordBatchStream> {
1222            unimplemented!()
1223        }
1224    }
1225
1226    #[derive(Debug)]
1227    struct CustomStatisticsProvider;
1228
1229    impl StatisticsProvider for CustomStatisticsProvider {
1230        fn compute_statistics(
1231            &self,
1232            plan: &dyn ExecutionPlan,
1233            child_stats: &[ExtendedStatistics],
1234        ) -> Result<StatisticsResult> {
1235            if plan.downcast_ref::<CustomExec>().is_some() {
1236                Ok(StatisticsResult::Computed(child_stats[0].clone()))
1237            } else {
1238                Ok(StatisticsResult::Delegate)
1239            }
1240        }
1241    }
1242
1243    #[test]
1244    fn test_custom_provider_for_custom_exec() -> Result<()> {
1245        let mut engine = StatisticsRegistry::new();
1246        engine.register(Arc::new(CustomStatisticsProvider));
1247
1248        let source = make_source(1000);
1249        let custom: Arc<dyn ExecutionPlan> = Arc::new(CustomExec { input: source });
1250
1251        let stats = engine.compute(custom.as_ref())?;
1252        assert!(matches!(stats.base.num_rows, Precision::Exact(1000)));
1253        Ok(())
1254    }
1255
1256    #[derive(Debug)]
1257    struct OverrideFilterProvider {
1258        fixed_selectivity: f64,
1259    }
1260
1261    impl StatisticsProvider for OverrideFilterProvider {
1262        fn compute_statistics(
1263            &self,
1264            plan: &dyn ExecutionPlan,
1265            child_stats: &[ExtendedStatistics],
1266        ) -> Result<StatisticsResult> {
1267            if plan.downcast_ref::<FilterExec>().is_some() {
1268                if let Some(&input_rows) = child_stats[0].base.num_rows.get_value() {
1269                    let estimated = (input_rows as f64 * self.fixed_selectivity) as usize;
1270                    Ok(StatisticsResult::Computed(ExtendedStatistics::from(
1271                        Statistics {
1272                            num_rows: Precision::Inexact(estimated),
1273                            total_byte_size: Precision::Absent,
1274                            column_statistics: child_stats[0]
1275                                .base
1276                                .column_statistics
1277                                .clone(),
1278                        },
1279                    )))
1280                } else {
1281                    Ok(StatisticsResult::Delegate)
1282                }
1283            } else {
1284                Ok(StatisticsResult::Delegate)
1285            }
1286        }
1287    }
1288
1289    #[test]
1290    fn test_override_builtin_operator() -> Result<()> {
1291        let mut engine = StatisticsRegistry::new();
1292        engine.register(Arc::new(OverrideFilterProvider {
1293            fixed_selectivity: 0.1,
1294        }));
1295
1296        let source = make_source(1000);
1297        let filter: Arc<dyn ExecutionPlan> =
1298            Arc::new(FilterExec::try_new(lit(true), source)?);
1299
1300        let stats = engine.compute(filter.as_ref())?;
1301        assert!(matches!(stats.base.num_rows, Precision::Inexact(100)));
1302        Ok(())
1303    }
1304
1305    #[test]
1306    fn test_filter_statistics_propagation() -> Result<()> {
1307        let engine = StatisticsRegistry::new();
1308        let source = make_source(1000);
1309        let predicate = lit(true);
1310        let filter: Arc<dyn ExecutionPlan> =
1311            Arc::new(FilterExec::try_new(predicate, source)?);
1312
1313        let stats = engine.compute(filter.as_ref())?;
1314        assert!(stats.base.num_rows.get_value().unwrap_or(&0) <= &1000);
1315        Ok(())
1316    }
1317
1318    #[test]
1319    fn test_filter_adjusts_ndv_by_selectivity() -> Result<()> {
1320        use datafusion_common::ScalarValue;
1321        use datafusion_expr::Operator;
1322        use datafusion_physical_expr::expressions::{
1323            BinaryExpr, Column as PhysColumn, Literal,
1324        };
1325
1326        // Source: 1000 rows, NDV(a)=1000 (unique), NDV(b)=800 (near-unique)
1327        // With NDV close to num_rows, each value has ~1.25 rows, so filtering
1328        // visibly reduces the number of surviving distinct values.
1329        let schema = make_schema(); // "a" Int32, "b" Int32
1330        let col_stats = vec![
1331            {
1332                let mut cs = ColumnStatistics::new_unknown();
1333                cs.distinct_count = Precision::Exact(1000);
1334                cs.min_value = Precision::Exact(ScalarValue::Int32(Some(1)));
1335                cs.max_value = Precision::Exact(ScalarValue::Int32(Some(1000)));
1336                cs
1337            },
1338            {
1339                let mut cs = ColumnStatistics::new_unknown();
1340                cs.distinct_count = Precision::Exact(800);
1341                cs.min_value = Precision::Exact(ScalarValue::Int32(Some(1)));
1342                cs.max_value = Precision::Exact(ScalarValue::Int32(Some(800)));
1343                cs
1344            },
1345        ];
1346        let source: Arc<dyn ExecutionPlan> = Arc::new(MockSourceExec::with_column_stats(
1347            schema,
1348            Precision::Exact(1000),
1349            col_stats,
1350        ));
1351
1352        // Filter: a > 900 (selectivity ~10%, keeps values 901-1000)
1353        let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
1354            Arc::new(PhysColumn::new("a", 0)),
1355            Operator::Gt,
1356            Arc::new(Literal::new(ScalarValue::Int32(Some(900)))),
1357        ));
1358        let filter: Arc<dyn ExecutionPlan> =
1359            Arc::new(FilterExec::try_new(predicate, source)?);
1360
1361        let registry = StatisticsRegistry::with_providers(vec![
1362            Arc::new(FilterStatisticsProvider),
1363            Arc::new(DefaultStatisticsProvider),
1364        ]);
1365        let stats = registry.compute(filter.as_ref())?;
1366
1367        let output_ndv_a = stats.base.column_statistics[0]
1368            .distinct_count
1369            .get_value()
1370            .copied()
1371            .unwrap_or(0);
1372        let output_ndv_b = stats.base.column_statistics[1]
1373            .distinct_count
1374            .get_value()
1375            .copied()
1376            .unwrap_or(0);
1377
1378        // NDV(a): interval analysis narrows to [901,1000] -> ~100 distinct values
1379        assert!(
1380            output_ndv_a <= 100,
1381            "Expected NDV(a) <= 100 after filter, got {output_ndv_a}"
1382        );
1383        // NDV(b): not in predicate, but selectivity ~10% with 1.25 rows/value
1384        // means many distinct values are lost. ndv_after_selectivity(800, 1000, 0.1)
1385        // gives ~76. Significantly less than the original 800.
1386        assert!(
1387            output_ndv_b < 200,
1388            "Expected NDV(b) < 200 after filter, got {output_ndv_b}"
1389        );
1390        Ok(())
1391    }
1392
1393    #[test]
1394    fn test_projection_statistics_propagation() -> Result<()> {
1395        let engine = StatisticsRegistry::new();
1396        let source = make_source(1000);
1397        let schema = make_schema();
1398        let proj: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
1399            vec![(col("a", &schema)?, "a".to_string())],
1400            source,
1401        )?);
1402
1403        let stats = engine.compute(proj.as_ref())?;
1404        assert!(matches!(stats.base.num_rows, Precision::Exact(1000)));
1405        Ok(())
1406    }
1407
1408    #[test]
1409    fn test_passthrough_statistics_propagation() -> Result<()> {
1410        use crate::coalesce_partitions::CoalescePartitionsExec;
1411
1412        let engine = StatisticsRegistry::new();
1413        let source = make_source(1000);
1414        let coalesce: Arc<dyn ExecutionPlan> =
1415            Arc::new(CoalescePartitionsExec::new(source));
1416
1417        let stats = engine.compute(coalesce.as_ref())?;
1418        // PassthroughStatisticsProvider should propagate child row count unchanged
1419        assert_eq!(stats.base.num_rows, Precision::Exact(1000));
1420        Ok(())
1421    }
1422
1423    #[test]
1424    fn test_chain_priority() -> Result<()> {
1425        let mut engine = StatisticsRegistry::new();
1426        engine.register(Arc::new(OverrideFilterProvider {
1427            fixed_selectivity: 0.5,
1428        }));
1429        engine.register(Arc::new(CustomStatisticsProvider));
1430
1431        let source = make_source(1000);
1432
1433        // CustomExec handled by CustomStatisticsProvider
1434        let custom: Arc<dyn ExecutionPlan> = Arc::new(CustomExec {
1435            input: Arc::clone(&source),
1436        });
1437        let stats = engine.compute(custom.as_ref())?;
1438        assert!(matches!(stats.base.num_rows, Precision::Exact(1000)));
1439
1440        // FilterExec: CustomStatisticsProvider delegates, OverrideFilterProvider handles
1441        let filter: Arc<dyn ExecutionPlan> =
1442            Arc::new(FilterExec::try_new(lit(true), source)?);
1443        let stats = engine.compute(filter.as_ref())?;
1444        assert!(matches!(stats.base.num_rows, Precision::Inexact(500)));
1445
1446        Ok(())
1447    }
1448
1449    // =========================================================================
1450    // num_distinct_vals Utility Tests
1451    // =========================================================================
1452
1453    #[test]
1454    fn test_num_distinct_vals_basic() {
1455        assert_eq!(num_distinct_vals(0, 100), 0);
1456        assert_eq!(num_distinct_vals(100, 0), 0);
1457        assert_eq!(num_distinct_vals(100, 100), 100);
1458        assert_eq!(num_distinct_vals(100, 200), 100);
1459
1460        let ndv = num_distinct_vals(1000, 100);
1461        assert!((90..=100).contains(&ndv), "Expected ~95, got {ndv}");
1462
1463        let ndv = num_distinct_vals(1000, 500);
1464        assert!((350..=450).contains(&ndv), "Expected ~393, got {ndv}");
1465
1466        let ndv = num_distinct_vals(1_000_000, 10_000);
1467        assert!((9900..=10000).contains(&ndv), "Expected ~9950, got {ndv}");
1468
1469        let ndv = num_distinct_vals(1_000_000, 100);
1470        assert!((99..=100).contains(&ndv), "Expected ~100, got {ndv}");
1471    }
1472
1473    #[test]
1474    fn test_num_distinct_vals_small_domain() {
1475        let ndv = num_distinct_vals(10, 5);
1476        assert!((3..=5).contains(&ndv), "Expected ~4, got {ndv}");
1477
1478        assert_eq!(num_distinct_vals(10, 20), 10);
1479        assert_eq!(num_distinct_vals(10, 1), 1);
1480    }
1481
1482    #[test]
1483    fn test_ndv_after_selectivity() {
1484        let ndv = ndv_after_selectivity(1000, 10000, 0.1);
1485        assert!((600..=700).contains(&ndv), "Expected ~632, got {ndv}");
1486
1487        let ndv = ndv_after_selectivity(1000, 10000, 0.01);
1488        assert!((90..=100).contains(&ndv), "Expected ~95, got {ndv}");
1489
1490        assert_eq!(ndv_after_selectivity(1000, 10000, 0.0), 0);
1491        assert_eq!(ndv_after_selectivity(1000, 10000, 1.0), 1000);
1492        assert_eq!(ndv_after_selectivity(0, 10000, 0.5), 0);
1493    }
1494
1495    // =========================================================================
1496    // AggregateStatisticsProvider tests
1497    // =========================================================================
1498
1499    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
1500
1501    fn make_source_with_ndv(
1502        num_rows: usize,
1503        col_ndvs: Vec<Option<usize>>,
1504    ) -> Arc<dyn ExecutionPlan> {
1505        let fields: Vec<Field> = col_ndvs
1506            .iter()
1507            .enumerate()
1508            .map(|(i, _)| Field::new(format!("c{i}"), DataType::Int32, false))
1509            .collect();
1510        let schema = Arc::new(Schema::new(fields));
1511        let col_stats = col_ndvs
1512            .into_iter()
1513            .map(|ndv| {
1514                let mut cs = ColumnStatistics::new_unknown();
1515                if let Some(n) = ndv {
1516                    cs.distinct_count = Precision::Exact(n);
1517                }
1518                cs
1519            })
1520            .collect();
1521        Arc::new(MockSourceExec::with_column_stats(
1522            schema,
1523            Precision::Exact(num_rows),
1524            col_stats,
1525        ))
1526    }
1527
1528    fn make_aggregate(
1529        input: Arc<dyn ExecutionPlan>,
1530        group_by: PhysicalGroupBy,
1531    ) -> Result<Arc<dyn ExecutionPlan>> {
1532        Ok(Arc::new(AggregateExec::try_new(
1533            AggregateMode::Single,
1534            group_by,
1535            vec![],
1536            vec![],
1537            Arc::clone(&input),
1538            input.schema(),
1539        )?))
1540    }
1541
1542    #[test]
1543    fn test_aggregate_provider_with_ndv() -> Result<()> {
1544        let source = make_source_with_ndv(100, vec![Some(10)]);
1545        let group_by = PhysicalGroupBy::new_single(vec![(
1546            Arc::new(Column::new("c0", 0)),
1547            "c0".to_string(),
1548        )]);
1549        let agg = make_aggregate(source, group_by)?;
1550
1551        let registry = StatisticsRegistry::with_providers(vec![
1552            Arc::new(AggregateStatisticsProvider),
1553            Arc::new(DefaultStatisticsProvider),
1554        ]);
1555        let stats = registry.compute(agg.as_ref())?;
1556        assert_eq!(stats.base.num_rows, Precision::Inexact(10));
1557        Ok(())
1558    }
1559
1560    #[test]
1561    fn test_aggregate_provider_multi_column() -> Result<()> {
1562        let source = make_source_with_ndv(1000, vec![Some(10), Some(5)]);
1563        let group_by = PhysicalGroupBy::new_single(vec![
1564            (Arc::new(Column::new("c0", 0)), "c0".to_string()),
1565            (Arc::new(Column::new("c1", 1)), "c1".to_string()),
1566        ]);
1567        let agg = make_aggregate(source, group_by)?;
1568
1569        let registry = StatisticsRegistry::with_providers(vec![
1570            Arc::new(AggregateStatisticsProvider),
1571            Arc::new(DefaultStatisticsProvider),
1572        ]);
1573        let stats = registry.compute(agg.as_ref())?;
1574        // 10 * 5 = 50
1575        assert_eq!(stats.base.num_rows, Precision::Inexact(50));
1576        Ok(())
1577    }
1578
1579    #[test]
1580    fn test_aggregate_provider_caps_at_input_rows() -> Result<()> {
1581        // NDV product (100 * 100 = 10_000) exceeds input rows (500)
1582        let source = make_source_with_ndv(500, vec![Some(100), Some(100)]);
1583        let group_by = PhysicalGroupBy::new_single(vec![
1584            (Arc::new(Column::new("c0", 0)), "c0".to_string()),
1585            (Arc::new(Column::new("c1", 1)), "c1".to_string()),
1586        ]);
1587        let agg = make_aggregate(source, group_by)?;
1588
1589        let registry = StatisticsRegistry::with_providers(vec![
1590            Arc::new(AggregateStatisticsProvider),
1591            Arc::new(DefaultStatisticsProvider),
1592        ]);
1593        let stats = registry.compute(agg.as_ref())?;
1594        assert_eq!(stats.base.num_rows, Precision::Inexact(500));
1595        Ok(())
1596    }
1597
#[test]
fn test_aggregate_provider_no_ndv_delegates() -> Result<()> {
    // The only GROUP BY column carries no NDV, so the aggregate provider is
    // expected to hand the plan over to DefaultStatisticsProvider (which
    // relies on the operator's partition_statistics).
    let input = make_source_with_ndv(100, vec![None]);
    let keys = PhysicalGroupBy::new_single(vec![(
        Arc::new(Column::new("c0", 0)),
        String::from("c0"),
    )]);
    let plan = make_aggregate(input, keys)?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(AggregateStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let rows = registry.compute(plan.as_ref())?.base.num_rows;
    // NOTE(review): this predicate appears to accept every `Precision`
    // variant (a value is present, or it is Absent), so the test effectively
    // only checks that `compute` succeeds - consider pinning the built-in
    // estimate instead.
    assert!(matches!(rows, Precision::Absent) || rows.get_value().is_some());
    Ok(())
}
1620
#[test]
fn test_aggregate_provider_non_column_expr_delegates() -> Result<()> {
    let input = make_source_with_ndv(100, vec![Some(10), Some(5)]);
    // The grouping key is the computed expression `c0 + c1` rather than a
    // bare column reference.
    let sum_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("c0", 0)),
        Operator::Plus,
        Arc::new(Column::new("c1", 1)),
    ));
    let keys = PhysicalGroupBy::new_single(vec![(sum_expr, String::from("sum"))]);
    let plan = make_aggregate(input, keys)?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(AggregateStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let rows = registry.compute(plan.as_ref())?.base.num_rows;
    // The provider is expected to delegate because the key is not a Column.
    // NOTE(review): this predicate appears to hold for every `Precision`
    // variant, so only the success of `compute` is really verified.
    assert!(matches!(rows, Precision::Absent) || rows.get_value().is_some());
    Ok(())
}
1645
#[test]
fn test_aggregate_provider_grouping_sets() -> Result<()> {
    let input = make_source_with_ndv(1000, vec![Some(10), Some(5)]);

    // GROUPING SETS: (c0), (c1), (c0, c1) -> 3 groups.
    let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
        (Arc::new(Column::new("c0", 0)), String::from("c0")),
        (Arc::new(Column::new("c1", 1)), String::from("c1")),
    ];
    // NULL placeholders, one per grouping column.
    let null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
        (
            Arc::new(Literal::new(ScalarValue::Int32(None))),
            String::from("c0"),
        ),
        (
            Arc::new(Literal::new(ScalarValue::Int32(None))),
            String::from("c1"),
        ),
    ];
    // One mask per set; `true` marks a column replaced by NULL in that set.
    let set_masks = vec![
        vec![false, true],  // (c0, NULL) - group by c0 only
        vec![true, false],  // (NULL, c1) - group by c1 only
        vec![false, false], // (c0, c1)   - group by both
    ];
    let keys = PhysicalGroupBy::new(group_exprs, null_exprs, set_masks, true);
    let plan = make_aggregate(input, keys)?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(AggregateStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let rows = registry.compute(plan.as_ref())?.base.num_rows;
    // With several grouping sets the provider is expected to delegate to
    // DefaultStatisticsProvider; the concrete value is left to the built-in
    // partition_statistics implementation.
    // NOTE(review): this predicate appears to hold for every `Precision`
    // variant, so only the success of `compute` is really verified.
    assert!(matches!(rows, Precision::Absent) || rows.get_value().is_some());
    Ok(())
}
1688
#[test]
fn test_aggregate_provider_partial_delegates() -> Result<()> {
    // A Partial aggregate emits groups per partition, so the provider should
    // delegate instead of applying global NDV bounds.
    let input = make_source_with_ndv(100, vec![Some(10)]);
    let keys = PhysicalGroupBy::new_single(vec![(
        Arc::new(Column::new("c0", 0)),
        String::from("c0"),
    )]);
    let partial = AggregateExec::try_new(
        AggregateMode::Partial,
        keys,
        vec![],
        vec![],
        Arc::clone(&input),
        input.schema(),
    )?;
    let plan: Arc<dyn ExecutionPlan> = Arc::new(partial);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(AggregateStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let rows = registry.compute(plan.as_ref())?.base.num_rows;
    // Expected to fall through to DefaultStatisticsProvider; the concrete
    // value is left to the built-in implementation.
    // NOTE(review): this predicate appears to hold for every `Precision`
    // variant, so only the success of `compute` is really verified.
    assert!(matches!(rows, Precision::Absent) || rows.get_value().is_some());
    Ok(())
}
1720
1721    // =========================================================================
1722    // JoinStatisticsProvider tests
1723    // =========================================================================
1724
1725    use crate::joins::{HashJoinExec, PartitionMode};
1726    use datafusion_common::{JoinType, NullEquality};
1727
1728    fn make_source_with_ndv_2col(
1729        num_rows: usize,
1730        ndv_a: Option<usize>,
1731    ) -> Arc<dyn ExecutionPlan> {
1732        let schema = make_schema(); // "a" Int32, "b" Int32
1733        let col_stats = vec![
1734            {
1735                let mut cs = ColumnStatistics::new_unknown();
1736                if let Some(n) = ndv_a {
1737                    cs.distinct_count = Precision::Exact(n);
1738                }
1739                cs
1740            },
1741            ColumnStatistics::new_unknown(),
1742        ];
1743        Arc::new(MockSourceExec::with_column_stats(
1744            schema,
1745            Precision::Exact(num_rows),
1746            col_stats,
1747        ))
1748    }
1749
1750    fn make_hash_join(
1751        left: Arc<dyn ExecutionPlan>,
1752        right: Arc<dyn ExecutionPlan>,
1753    ) -> Result<Arc<dyn ExecutionPlan>> {
1754        let _schema = make_schema();
1755        let on: crate::joins::JoinOn = vec![(
1756            Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>,
1757            Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>,
1758        )];
1759        Ok(Arc::new(HashJoinExec::try_new(
1760            left,
1761            right,
1762            on,
1763            None,
1764            &JoinType::Inner,
1765            None,
1766            PartitionMode::CollectLeft,
1767            NullEquality::NullEqualsNull,
1768            false,
1769        )?))
1770    }
1771
#[test]
fn test_join_provider_with_ndv() -> Result<()> {
    // left:  1000 rows, NDV(a) = 100
    // right:  500 rows, NDV(a) = 50
    // expected = 1000 * 500 / max(100, 50) = 5000
    let join = make_hash_join(
        make_source_with_ndv_2col(1000, Some(100)),
        make_source_with_ndv_2col(500, Some(50)),
    )?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(join.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(5000));
    Ok(())
}
1788
#[test]
fn test_join_provider_uses_actual_key_column_ndv() -> Result<()> {
    // The join key is column "b" (index 1) and only "b" carries an NDV.
    // Looking up the first column ("a") instead would find no NDV and fall
    // back to the Cartesian product; the provider must consult column 1.
    // left: 1000 rows, NDV(b)=50; right: 500 rows, NDV(b)=25
    // expected = 1000 * 500 / max(50, 25) = 10000
    let schema = make_schema(); // "a" Int32, "b" Int32
    let source_with_ndv_b = |num_rows: usize, ndv_b: usize| -> Arc<dyn ExecutionPlan> {
        let stats_a = ColumnStatistics::new_unknown(); // "a": no NDV
        let mut stats_b = ColumnStatistics::new_unknown();
        stats_b.distinct_count = Precision::Exact(ndv_b);
        Arc::new(MockSourceExec::with_column_stats(
            Arc::clone(&schema),
            Precision::Exact(num_rows),
            vec![stats_a, stats_b],
        ))
    };

    // Equi-join on "b" (index 1) for both sides.
    let key_b = || -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("b", 1)) };
    let on: crate::joins::JoinOn = vec![(key_b(), key_b())];
    let hash_join = HashJoinExec::try_new(
        source_with_ndv_b(1000, 50),
        source_with_ndv_b(500, 25),
        on,
        None,
        &JoinType::Inner,
        None,
        PartitionMode::CollectLeft,
        NullEquality::NullEqualsNull,
        false,
    )?;
    let join: Arc<dyn ExecutionPlan> = Arc::new(hash_join);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(join.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(10_000));
    Ok(())
}
1842
#[test]
fn test_join_provider_multi_key_ndv() -> Result<()> {
    // Two-key join: ON a.a = b.a AND a.b = b.b
    // left:  1000 rows, NDV(a)=100, NDV(b)=20
    // right:  500 rows, NDV(a)=50,  NDV(b)=10
    // expected = 1000 * 500 / (max(100,50) * max(20,10)) = 500000 / 2000 = 250
    let schema = make_schema(); // "a" Int32, "b" Int32
    let exact_ndv = |ndv: usize| {
        let mut cs = ColumnStatistics::new_unknown();
        cs.distinct_count = Precision::Exact(ndv);
        cs
    };
    let source_with_ndvs =
        |num_rows: usize, ndv_a: usize, ndv_b: usize| -> Arc<dyn ExecutionPlan> {
            Arc::new(MockSourceExec::with_column_stats(
                Arc::clone(&schema),
                Precision::Exact(num_rows),
                vec![exact_ndv(ndv_a), exact_ndv(ndv_b)],
            ))
        };

    let column = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let on: crate::joins::JoinOn = vec![
        (column("a", 0), column("a", 0)),
        (column("b", 1), column("b", 1)),
    ];
    let hash_join = HashJoinExec::try_new(
        source_with_ndvs(1000, 100, 20),
        source_with_ndvs(500, 50, 10),
        on,
        None,
        &JoinType::Inner,
        None,
        PartitionMode::CollectLeft,
        NullEquality::NullEqualsNull,
        false,
    )?;
    let join: Arc<dyn ExecutionPlan> = Arc::new(hash_join);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(join.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(250));
    Ok(())
}
1904
#[test]
fn test_join_provider_fallback_cartesian() -> Result<()> {
    // Neither side has an NDV on the key, so the estimate is expected to be
    // the Cartesian product: 100 * 200 = 20_000.
    let join = make_hash_join(
        make_source_with_ndv_2col(100, None),
        make_source_with_ndv_2col(200, None),
    )?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(join.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(20_000));
    Ok(())
}
1920
#[test]
fn test_nl_join_delegates() -> Result<()> {
    use crate::joins::NestedLoopJoinExec;

    // A NestedLoopJoinExec may carry an arbitrary JoinFilter, so the provider
    // cannot safely assume a Cartesian result and delegates to the built-in.
    let nested_loop = NestedLoopJoinExec::try_new(
        make_source(100),
        make_source(200),
        None,
        &JoinType::Inner,
        None,
    )?;
    let join: Arc<dyn ExecutionPlan> = Arc::new(nested_loop);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let rows = registry.compute(join.as_ref())?.base.num_rows;
    // The value comes from the built-in partition_statistics.
    // NOTE(review): this predicate appears to hold for every `Precision`
    // variant, so only the success of `compute` is really verified.
    assert!(matches!(rows, Precision::Absent) || rows.get_value().is_some());
    Ok(())
}
1949
1950    fn make_hash_join_typed(
1951        left: Arc<dyn ExecutionPlan>,
1952        right: Arc<dyn ExecutionPlan>,
1953        join_type: JoinType,
1954    ) -> Result<Arc<dyn ExecutionPlan>> {
1955        let on: crate::joins::JoinOn = vec![(
1956            Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>,
1957            Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>,
1958        )];
1959        Ok(Arc::new(HashJoinExec::try_new(
1960            left,
1961            right,
1962            on,
1963            None,
1964            &join_type,
1965            None,
1966            PartitionMode::CollectLeft,
1967            NullEquality::NullEqualsNull,
1968            false,
1969        )?))
1970    }
1971
1972    fn compute_join_rows(
1973        left_rows: usize,
1974        left_ndv: Option<usize>,
1975        right_rows: usize,
1976        right_ndv: Option<usize>,
1977        join_type: JoinType,
1978    ) -> Result<Precision<usize>> {
1979        let left = make_source_with_ndv_2col(left_rows, left_ndv);
1980        let right = make_source_with_ndv_2col(right_rows, right_ndv);
1981        let join = make_hash_join_typed(left, right, join_type)?;
1982        let registry = StatisticsRegistry::with_providers(vec![
1983            Arc::new(JoinStatisticsProvider),
1984            Arc::new(DefaultStatisticsProvider),
1985        ]);
1986        Ok(registry.compute(join.as_ref())?.base.num_rows)
1987    }
1988
#[test]
fn test_join_provider_left_outer() -> Result<()> {
    // Each case: (left_rows, left_ndv, right_rows, right_ndv, expected_rows).
    let cases = [
        // inner = 1000*500/100 = 5000, already >= left rows:
        // left outer = max(5000, 1000) = 5000
        (1000, Some(100), 500, Some(50), 5000),
        // inner = 1000*10/100 = 100, below left rows:
        // left outer = max(100, 1000) = 1000
        (1000, Some(100), 10, Some(100), 1000),
    ];
    for (left_rows, left_ndv, right_rows, right_ndv, expected) in cases {
        let rows = compute_join_rows(
            left_rows,
            left_ndv,
            right_rows,
            right_ndv,
            JoinType::Left,
        )?;
        assert_eq!(rows, Precision::Inexact(expected));
    }
    Ok(())
}
2006
#[test]
fn test_join_provider_right_outer() -> Result<()> {
    // Each case: (left_rows, left_ndv, right_rows, right_ndv, expected_rows).
    let cases = [
        // inner = 1000*10/100 = 100, right outer = max(100, 10) = 100
        (1000, Some(100), 10, Some(100), 100),
        // inner = 10*1000/100 = 100, right outer = max(100, 1000) = 1000
        (10, Some(100), 1000, Some(100), 1000),
    ];
    for (left_rows, left_ndv, right_rows, right_ndv, expected) in cases {
        let rows = compute_join_rows(
            left_rows,
            left_ndv,
            right_rows,
            right_ndv,
            JoinType::Right,
        )?;
        assert_eq!(rows, Precision::Inexact(expected));
    }
    Ok(())
}
2021
#[test]
fn test_join_provider_semi_join() -> Result<()> {
    // Each case:
    // (left_rows, left_ndv, right_rows, right_ndv, join_type, expected_rows).
    let cases = [
        // inner = 5000, left semi = min(5000, 1000) = 1000
        (1000, Some(100), 500, Some(50), JoinType::LeftSemi, 1000),
        // inner = 5000, right semi = min(5000, 500) = 500
        (1000, Some(100), 500, Some(50), JoinType::RightSemi, 500),
        // Cartesian fallback (no NDV): inner = 1000*500 = 500000,
        // left semi = min(500000, 1000) = 1000 (selectivity = 1.0)
        (1000, None, 500, None, JoinType::LeftSemi, 1000),
    ];
    for (left_rows, left_ndv, right_rows, right_ndv, join_type, expected) in cases {
        let rows =
            compute_join_rows(left_rows, left_ndv, right_rows, right_ndv, join_type)?;
        assert_eq!(rows, Precision::Inexact(expected));
    }
    Ok(())
}
2042
#[test]
fn test_join_provider_anti_join() -> Result<()> {
    // Each case:
    // (left_rows, left_ndv, right_rows, right_ndv, join_type, expected_rows).
    let cases = [
        // inner = 1000*10/100 = 100, left anti = 1000 - min(100, 1000) = 900
        (1000, Some(100), 10, Some(100), JoinType::LeftAnti, 900),
        // inner = 5000, right anti = 500 - min(5000, 500) = 0
        (1000, Some(100), 500, Some(50), JoinType::RightAnti, 0),
    ];
    for (left_rows, left_ndv, right_rows, right_ndv, join_type, expected) in cases {
        let rows =
            compute_join_rows(left_rows, left_ndv, right_rows, right_ndv, join_type)?;
        assert_eq!(rows, Precision::Inexact(expected));
    }
    Ok(())
}
2057
2058    // =========================================================================
2059    // CrossJoinExec tests (handled by JoinStatisticsProvider)
2060    // =========================================================================
2061
#[test]
fn test_cross_join_provider_exact() -> Result<()> {
    use crate::joins::CrossJoinExec;

    // Cross join of a 100-row and a 200-row source.
    let cross = CrossJoinExec::new(make_source(100), make_source(200));
    let join: Arc<dyn ExecutionPlan> = Arc::new(cross);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(JoinStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(join.as_ref())?.base.num_rows;
    // Exact input row counts should yield an Exact product: 100 * 200.
    assert_eq!(estimated_rows, Precision::Exact(20_000));
    Ok(())
}
2078
2079    // =========================================================================
2080    // LimitStatisticsProvider tests
2081    // =========================================================================
2082
2083    use crate::limit::{GlobalLimitExec, LocalLimitExec};
2084
#[test]
fn test_limit_provider_caps_output() -> Result<()> {
    // 1000 input rows with fetch = 100: the output is capped at the fetch.
    let plan: Arc<dyn ExecutionPlan> =
        Arc::new(LocalLimitExec::new(make_source(1000), 100));

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(LimitStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(100));
    Ok(())
}
2099
#[test]
fn test_limit_provider_input_smaller_than_fetch() -> Result<()> {
    // 50 input rows with fetch = 200: the limit does not bind, output = input.
    let plan: Arc<dyn ExecutionPlan> =
        Arc::new(LocalLimitExec::new(make_source(50), 200));

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(LimitStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(50));
    Ok(())
}
2114
#[test]
fn test_global_limit_provider_skip_and_fetch() -> Result<()> {
    // 1000 rows, skip 200, fetch 100: exactly 100 rows remain.
    let global_limit = GlobalLimitExec::new(make_source(1000), 200, Some(100));
    let plan: Arc<dyn ExecutionPlan> = Arc::new(global_limit);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(LimitStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(100));
    Ok(())
}
2130
#[test]
fn test_global_limit_provider_skip_exceeds_rows() -> Result<()> {
    // 100 rows but skip = 200: nothing is left after the skip, so 0 rows.
    let global_limit = GlobalLimitExec::new(make_source(100), 200, Some(50));
    let plan: Arc<dyn ExecutionPlan> = Arc::new(global_limit);

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(LimitStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(0));
    Ok(())
}
2146
#[test]
fn test_limit_provider_inexact_input() -> Result<()> {
    // With an Inexact(1000) input and fetch = 100 the estimate must remain
    // Inexact: the real input could hold fewer than 100 rows.
    let input = make_source_with_precision(Precision::Inexact(1000));
    let plan: Arc<dyn ExecutionPlan> = Arc::new(LocalLimitExec::new(input, 100));

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(LimitStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(100));
    Ok(())
}
2162
2163    // =========================================================================
2164    // UnionStatisticsProvider tests
2165    // =========================================================================
2166
2167    use crate::union::UnionExec;
2168
2169    fn make_source_with_precision(num_rows: Precision<usize>) -> Arc<dyn ExecutionPlan> {
2170        Arc::new(MockSourceExec::new(make_schema(), num_rows))
2171    }
2172
#[test]
fn test_union_provider_sums_rows() -> Result<()> {
    // Two exact inputs: 300 + 700 = 1000 rows.
    let inputs = vec![make_source(300), make_source(700)];
    let plan = UnionExec::try_new(inputs)?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(UnionStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(1000));
    Ok(())
}
2185
#[test]
fn test_union_provider_three_inputs() -> Result<()> {
    // Three exact inputs: 100 + 200 + 300 = 600 rows.
    let inputs = vec![make_source(100), make_source(200), make_source(300)];
    let plan = UnionExec::try_new(inputs)?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(UnionStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Exact(600));
    Ok(())
}
2202
#[test]
fn test_union_provider_absent_propagates() -> Result<()> {
    // If any input has an unknown row count the total must be Absent too,
    // rather than a partial sum such as Inexact(300).
    let known = make_source(300);
    let unknown = make_source_with_precision(Precision::Absent);
    let plan = UnionExec::try_new(vec![known, unknown])?;

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(UnionStatisticsProvider),
        Arc::new(DefaultStatisticsProvider),
    ]);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Absent);
    Ok(())
}
2219
2220    // =========================================================================
2221    // ClosureStatisticsProvider tests
2222    // =========================================================================
2223
#[test]
fn test_closure_provider_basic() -> Result<()> {
    // The closure pins every FilterExec to Inexact(42) rows and delegates
    // for any other operator.
    let provider = ClosureStatisticsProvider::new(|plan, _child_stats| {
        if plan.downcast_ref::<FilterExec>().is_none() {
            return Ok(StatisticsResult::Delegate);
        }
        let fixed = Statistics {
            num_rows: Precision::Inexact(42),
            total_byte_size: Precision::Absent,
            column_statistics: vec![],
        };
        Ok(StatisticsResult::Computed(ExtendedStatistics::from(fixed)))
    });

    let registry = StatisticsRegistry::with_providers(vec![
        Arc::new(provider),
        Arc::new(DefaultStatisticsProvider),
    ]);

    let filter_node = FilterExec::try_new(lit(true), make_source(1000))?;
    let plan: Arc<dyn ExecutionPlan> = Arc::new(filter_node);
    let estimated_rows = registry.compute(plan.as_ref())?.base.num_rows;
    assert_eq!(estimated_rows, Precision::Inexact(42));
    Ok(())
}
2253
#[test]
fn test_closure_provider_distinguishes_nodes_by_child_stats() -> Result<()> {
    // Two FilterExec nodes sit on inputs of different sizes. The closure keys
    // on the child's row count to tell them apart, mirroring a cardinality
    // feedback setup where an observed count is matched to a plan node.
    let provider = ClosureStatisticsProvider::new(|plan, child_stats| {
        if plan.downcast_ref::<FilterExec>().is_none() {
            return Ok(StatisticsResult::Delegate);
        }
        // Map the known input sizes to their overridden estimates; any other
        // input size is left to the next provider.
        let overridden = match child_stats[0].base.num_rows.get_value().copied() {
            Some(500) => 100,
            Some(200) => 50,
            _ => return Ok(StatisticsResult::Delegate),
        };
        let stats = Statistics {
            num_rows: Precision::Inexact(overridden),
            total_byte_size: Precision::Absent,
            column_statistics: vec![],
        };
        Ok(StatisticsResult::Computed(ExtendedStatistics::from(stats)))
    });

    let registry = StatisticsRegistry::with_providers(vec![Arc::new(provider)]);

    let over_500: Arc<dyn ExecutionPlan> =
        Arc::new(FilterExec::try_new(lit(true), make_source(500))?);
    let over_200: Arc<dyn ExecutionPlan> =
        Arc::new(FilterExec::try_new(lit(true), make_source(200))?);

    let rows_500 = registry.compute(over_500.as_ref())?.base.num_rows;
    let rows_200 = registry.compute(over_200.as_ref())?.base.num_rows;

    assert_eq!(rows_500, Precision::Inexact(100));
    assert_eq!(rows_200, Precision::Inexact(50));
    Ok(())
}
2297}