Skip to main content

samkhya_datafusion/
table_provider.rs

1//! `SamkhyaTableProvider` — the primary integration point for injecting
2//! samkhya-corrected column statistics into DataFusion's query planning.
3//!
4//! # Wrapping point: `TableProvider::statistics()`
5//!
6//! DataFusion attaches statistics to table providers, not to logical-plan
7//! nodes. The [`TableProvider`] trait exposes a `statistics()` hook
8//! (returning `Option<Statistics>`) that the planner consults when reasoning
9//! about cardinality, join order, and filter selectivity. Rewriting a
10//! `LogicalPlan` to "inject" stats is the wrong layer — that is observe-only
11//! plumbing. The right layer is a `TableProvider` shim that delegates every
12//! method to an inner provider *except* `statistics()`, where it folds in
13//! samkhya's feedback-driven corrections.
14//!
//! We considered three wrapping points and chose the first as the primary
//! override surface; the second is additionally applied inside `scan()` so
//! the corrected values also reach the physical plan:
//!
//! 1. **`TableProvider::statistics()`** (this module). Clean, stable surface
//!    in DataFusion 46. The planner calls it during analysis. Every adapter
//!    (Parquet, CSV, MemTable, Iceberg) flows through the same hook, so the
//!    shim is provider-agnostic.
//! 2. `ExecutionPlan::statistics()`. Lower in the stack — requires
//!    wrapping the scan-side `ExecutionPlan` returned from `scan()`. Useful
//!    when the inner provider's logical stats are absent but its physical
//!    plan has them. `scan()` in this module also wraps the inner plan in
//!    `SamkhyaStatsExec`, because the physical planner reads statistics
//!    from the plan returned by `scan()` rather than from the provider.
//! 3. `OptimizerRule` rewriting `TableScan::source`. The original scaffold
//!    direction. The rewrite must construct a new `TableSource` (the logical
//!    counterpart of `TableProvider`) — duplicate state, version-fragile,
//!    and never propagates into the physical layer where the planner
//!    actually consults stats. Kept around as observe-only telemetry
//!    ([`crate::SamkhyaOptimizerRule`]).
31//!
32//! # LpBound posture
33//!
34//! Every value translated into DataFusion's `Precision<T>` is wrapped as
35//! [`Precision::Inexact`]. samkhya's corrections are feedback-driven
36//! estimates clamped by the LpBound pessimistic ceiling; they are never
37//! exact catalog counts. `Inexact` is the precision DataFusion's
38//! cost-based optimizer treats as "use this, but do not assume zero error".
39
40use std::any::Any;
41use std::borrow::Cow;
42use std::collections::HashMap;
43use std::sync::Arc;
44use std::sync::atomic::{AtomicUsize, Ordering};
45
46use async_trait::async_trait;
47use datafusion::arrow::datatypes::SchemaRef;
48use datafusion::catalog::Session;
49use datafusion::common::stats::Precision;
50use datafusion::common::{ColumnStatistics, Constraints, Result, Statistics};
51use datafusion::datasource::{TableProvider, TableType};
52use datafusion::logical_expr::dml::InsertOp;
53use datafusion::logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown};
54use datafusion::physical_plan::ExecutionPlan;
55use samkhya_core::stats::ColumnStats;
56
57use crate::physical_plan::SamkhyaStatsExec;
58use crate::stats_provider::to_datafusion_column_statistics;
59
60/// A [`TableProvider`] wrapper that overrides `statistics()` with
61/// samkhya-corrected column statistics while delegating every other method
62/// to the inner provider.
63///
64/// # Builder
65///
66/// ```ignore
67/// use std::sync::Arc;
68/// use samkhya_datafusion::SamkhyaTableProvider;
69/// use samkhya_core::stats::ColumnStats;
70///
71/// let wrapped = SamkhyaTableProvider::new(Arc::new(inner))
72///     .with_column_stats(0, ColumnStats::new().with_row_count(999).with_distinct_count(42));
73/// ```
74///
75/// # Stats fold semantics
76///
77/// `statistics()` builds a `Statistics` whose per-column entries come from
78/// the samkhya override map where present, falling back to the inner
79/// provider's stats (or `ColumnStatistics::new_unknown()` if the inner
80/// provider returns `None`). Table-level `num_rows` is taken from the
81/// override map's most authoritative `row_count`: the maximum across all
82/// override entries, since samkhya's per-column stats describe the same
83/// underlying relation. If no override carries a row count, the inner
84/// provider's `num_rows` is preserved.
85#[derive(Debug)]
86pub struct SamkhyaTableProvider {
87    inner: Arc<dyn TableProvider>,
88    overrides: HashMap<usize, ColumnStats>,
89    /// Number of times `statistics()` has been invoked by the planner.
90    /// Exposed for integration tests; not part of the public optimization
91    /// contract.
92    stats_calls: AtomicUsize,
93}
94
95impl SamkhyaTableProvider {
96    /// Wrap an existing provider. No overrides are installed until
97    /// [`Self::with_column_stats`] is called.
98    pub fn new(inner: Arc<dyn TableProvider>) -> Self {
99        Self {
100            inner,
101            overrides: HashMap::new(),
102            stats_calls: AtomicUsize::new(0),
103        }
104    }
105
106    /// Install a samkhya override for the column at `col_idx`.
107    ///
108    /// Indices refer to positions in the inner provider's [`SchemaRef`].
109    /// Existing overrides for the same index are replaced.
110    pub fn with_column_stats(mut self, col_idx: usize, stats: ColumnStats) -> Self {
111        self.overrides.insert(col_idx, stats);
112        self
113    }
114
115    /// Number of times `statistics()` has been called on this wrapper.
116    ///
117    /// Useful for assertions in integration tests that verify the planner
118    /// actually consulted the corrected stats.
119    pub fn stats_call_count(&self) -> usize {
120        self.stats_calls.load(Ordering::SeqCst)
121    }
122
123    /// Borrow the override map. Read-only access for diagnostics.
124    pub fn overrides(&self) -> &HashMap<usize, ColumnStats> {
125        &self.overrides
126    }
127}
128
129#[async_trait]
130impl TableProvider for SamkhyaTableProvider {
131    fn as_any(&self) -> &dyn Any {
132        self
133    }
134
135    fn schema(&self) -> SchemaRef {
136        self.inner.schema()
137    }
138
139    fn constraints(&self) -> Option<&Constraints> {
140        self.inner.constraints()
141    }
142
143    fn table_type(&self) -> TableType {
144        self.inner.table_type()
145    }
146
147    fn get_table_definition(&self) -> Option<&str> {
148        self.inner.get_table_definition()
149    }
150
151    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
152        self.inner.get_logical_plan()
153    }
154
155    fn get_column_default(&self, column: &str) -> Option<&Expr> {
156        self.inner.get_column_default(column)
157    }
158
159    async fn scan(
160        &self,
161        state: &dyn Session,
162        projection: Option<&Vec<usize>>,
163        filters: &[Expr],
164        limit: Option<usize>,
165    ) -> Result<Arc<dyn ExecutionPlan>> {
166        // Ask the inner provider for its native scan exec, then wrap it
167        // in `SamkhyaStatsExec` so the physical layer publishes the
168        // samkhya-corrected `Statistics` to every downstream operator.
169        //
170        // This is the actual injection path: DataFusion 46's mainline
171        // planner does not consult `TableProvider::statistics()` when
172        // building the physical plan — it calls `scan()` and trusts the
173        // returned `ExecutionPlan::statistics()`. So the only reliable
174        // way to flow corrected row counts into
175        // `physical.statistics()?.num_rows` is to override at the exec
176        // level, here.
177        //
178        // If we have no overrides installed we still wrap, using the
179        // statistics() fold as-is — the cost is one cheap delegation
180        // call per execute()/statistics() and the inner provider's
181        // values are preserved by the merge in `self.statistics()`.
182        let inner_plan = self.inner.scan(state, projection, filters, limit).await?;
183
184        // Project the table-level Statistics onto the scan's *output*
185        // schema (which honours `projection`), so the wrapped exec
186        // reports column_statistics aligned to the columns it actually
187        // emits — not the full table schema. This matches what
188        // `TableProvider`-aware execs (`DataSourceExec`) already do.
189        let full_stats = self
190            .statistics()
191            .unwrap_or_else(|| Statistics::new_unknown(self.inner.schema().as_ref()));
192        let output_stats = full_stats.project(projection);
193
194        Ok(Arc::new(SamkhyaStatsExec::new(inner_plan, output_stats)))
195    }
196
197    fn supports_filters_pushdown(
198        &self,
199        filters: &[&Expr],
200    ) -> Result<Vec<TableProviderFilterPushDown>> {
201        self.inner.supports_filters_pushdown(filters)
202    }
203
204    /// Fold samkhya overrides into the inner provider's `Statistics`.
205    ///
206    /// Schema order is preserved: column `i` in the returned
207    /// `column_statistics` corresponds to field `i` of `self.schema()`.
208    fn statistics(&self) -> Option<Statistics> {
209        // Record the call so tests can assert the planner consulted us.
210        self.stats_calls.fetch_add(1, Ordering::SeqCst);
211
212        let schema = self.inner.schema();
213        let n_fields = schema.fields().len();
214
215        // Start from the inner provider's stats; fall back to an unknown
216        // skeleton sized to the schema so we always return Some(_).
217        let mut base = self
218            .inner
219            .statistics()
220            .unwrap_or_else(|| Statistics::new_unknown(schema.as_ref()));
221
222        // Defensive: if the inner provider returned a column_statistics vec
223        // whose length disagrees with the schema, normalise to schema size.
224        if base.column_statistics.len() != n_fields {
225            base.column_statistics = Statistics::unknown_column(schema.as_ref());
226        }
227
228        // Per-column merge: override wins where present, inner is preserved
229        // otherwise. samkhya values are translated as Inexact per the
230        // LpBound conservative posture.
231        for (col_idx, override_stats) in &self.overrides {
232            if *col_idx >= n_fields {
233                // Index out of range — skip rather than panic; this can
234                // happen if the schema changes under us.
235                continue;
236            }
237            let translated = to_datafusion_column_statistics(override_stats);
238            base.column_statistics[*col_idx] =
239                merge_column_stats(base.column_statistics[*col_idx].clone(), translated);
240        }
241
242        // Table-level row count: take the max row_count across overrides
243        // (they all describe the same relation, so any populated value is
244        // a corrected estimate of |R|). If no override carries a row
245        // count, keep the inner provider's value.
246        //
247        // WAVE5-RC2: plan-memory-monotonic guard. Never publish a row
248        // count smaller than the inner provider's native estimate. The
249        // hash-join build-side sizing in DataFusion 46 picks the smaller
250        // side as the build side; if samkhya under-estimates and the
251        // actual data is much larger, the build hash table grows past
252        // its sized allocation and the planner walks into an OOM. Capping
253        // the published row count at `max(samkhya, native)` preserves
254        // samkhya's win when it has a larger / more accurate NDV-derived
255        // row count, while never pushing the planner toward a smaller
256        // build side than it would have chosen with no samkhya input.
257        // Symmetric guard on `SamkhyaStatsExec::statistics()` enforces
258        // the same invariant at the physical layer.
259        let override_row_count = self.overrides.values().filter_map(|s| s.row_count).max();
260        if let Some(rc) = override_row_count {
261            let rc_usize = rc as usize;
262            let monotone_rc = match base.num_rows {
263                Precision::Exact(n) | Precision::Inexact(n) => rc_usize.max(n),
264                Precision::Absent => rc_usize,
265            };
266            base.num_rows = Precision::Inexact(monotone_rc);
267            // Total byte size: if the inner provider reported it, relax to
268            // inexact since the row count has shifted; otherwise leave
269            // absent.
270            base.total_byte_size = match base.total_byte_size {
271                Precision::Exact(n) | Precision::Inexact(n) => Precision::Inexact(n),
272                Precision::Absent => Precision::Absent,
273            };
274        }
275
276        Some(base)
277    }
278
279    async fn insert_into(
280        &self,
281        state: &dyn Session,
282        input: Arc<dyn ExecutionPlan>,
283        insert_op: InsertOp,
284    ) -> Result<Arc<dyn ExecutionPlan>> {
285        self.inner.insert_into(state, input, insert_op).await
286    }
287}
288
289/// Merge a samkhya-translated `ColumnStatistics` over a base one.
290///
291/// Fields where the override is `Precision::Absent` fall through to the
292/// base. Fields where the override carries an `Inexact` value win for
293/// null_count / max_value / min_value / sum_value. **`distinct_count`
294/// applies the WAVE5-RC2 plan-memory-monotonic guard** — the published
295/// value is `max(samkhya_ndv, native_ndv)`. NDV drives hash-join
296/// build-side sizing; never publishing a smaller distinct count than
297/// DataFusion's native estimate prevents the corrected arm from
298/// pushing the planner toward a smaller build hash table than it would
299/// have chosen with no samkhya input.
300fn merge_column_stats(base: ColumnStatistics, ovr: ColumnStatistics) -> ColumnStatistics {
301    ColumnStatistics {
302        null_count: pick(base.null_count, ovr.null_count),
303        max_value: pick(base.max_value, ovr.max_value),
304        min_value: pick(base.min_value, ovr.min_value),
305        sum_value: pick(base.sum_value, ovr.sum_value),
306        distinct_count: pick_max_usize(base.distinct_count, ovr.distinct_count),
307    }
308}
309
310/// Plan-memory-monotonic merge for `Precision<usize>` cardinality
311/// fields. Returns `Precision::Inexact(max(base, ovr))` when both
312/// carry a value, the present one when only one does, and
313/// `Precision::Absent` when neither does. Used for `distinct_count`
314/// merges so samkhya never publishes an NDV smaller than the inner
315/// provider would have on its own.
316fn pick_max_usize(base: Precision<usize>, ovr: Precision<usize>) -> Precision<usize> {
317    let base_val = match base {
318        Precision::Exact(n) | Precision::Inexact(n) => Some(n),
319        Precision::Absent => None,
320    };
321    let ovr_val = match ovr {
322        Precision::Exact(n) | Precision::Inexact(n) => Some(n),
323        Precision::Absent => None,
324    };
325    match (base_val, ovr_val) {
326        (Some(b), Some(o)) => Precision::Inexact(b.max(o)),
327        (Some(b), None) => Precision::Inexact(b),
328        (None, Some(o)) => Precision::Inexact(o),
329        (None, None) => Precision::Absent,
330    }
331}
332
333fn pick<T>(base: Precision<T>, ovr: Precision<T>) -> Precision<T>
334where
335    T: std::fmt::Debug + Clone + PartialEq + Eq + PartialOrd,
336{
337    match ovr {
338        Precision::Absent => base,
339        other => other,
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use datafusion::arrow::array::Int64Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;

    /// Non-nullable `Int64` field with the given name.
    fn int64_field(name: &str) -> Field {
        Field::new(name, DataType::Int64, false)
    }

    /// Two-column (`a`, `b`) in-memory table holding five rows.
    fn tiny_mem_table() -> Arc<MemTable> {
        let schema = Arc::new(Schema::new(vec![int64_field("a"), int64_field("b")]));
        let col_a = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let col_b = Int64Array::from(vec![10, 20, 30, 40, 50]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(col_a), Arc::new(col_b)],
        )
        .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        Arc::new(table)
    }

    /// Minimal provider whose `statistics()` reports a fixed
    /// `Precision::Inexact` row count and unknown column statistics.
    /// Scanning it is not supported.
    #[derive(Debug)]
    struct MockProvider {
        schema: SchemaRef,
        native_rows: usize,
    }

    #[async_trait]
    impl TableProvider for MockProvider {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn table_type(&self) -> TableType {
            TableType::Base
        }

        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unreachable!("scan not exercised by this test")
        }

        fn statistics(&self) -> Option<Statistics> {
            let mut stats = Statistics::new_unknown(self.schema.as_ref());
            stats.num_rows = Precision::Inexact(self.native_rows);
            Some(stats)
        }
    }

    /// `ColumnStatistics` that is unknown everywhere except for an inexact
    /// distinct count.
    fn ndv_only(ndv: usize) -> ColumnStatistics {
        ColumnStatistics {
            null_count: Precision::Absent,
            max_value: Precision::Absent,
            min_value: Precision::Absent,
            sum_value: Precision::Absent,
            distinct_count: Precision::Inexact(ndv),
        }
    }

    #[test]
    fn builder_records_overrides() {
        let wrapped = SamkhyaTableProvider::new(tiny_mem_table())
            .with_column_stats(0, ColumnStats::new().with_row_count(999));

        let installed = wrapped.overrides();
        assert_eq!(installed.len(), 1);
        assert_eq!(installed[&0].row_count, Some(999));
    }

    #[test]
    fn statistics_overrides_row_count() {
        let correction = ColumnStats::new()
            .with_row_count(999)
            .with_distinct_count(42);
        let wrapped = SamkhyaTableProvider::new(tiny_mem_table()).with_column_stats(0, correction);

        let stats = wrapped.statistics().expect("statistics present");

        assert_eq!(stats.num_rows, Precision::Inexact(999));
        assert_eq!(
            stats.column_statistics[0].distinct_count,
            Precision::Inexact(42)
        );
        // Exactly one fold was requested above.
        assert_eq!(wrapped.stats_call_count(), 1);
    }

    #[test]
    fn statistics_falls_back_for_unoverridden_columns() {
        let wrapped = SamkhyaTableProvider::new(tiny_mem_table())
            .with_column_stats(0, ColumnStats::new().with_distinct_count(7));

        let stats = wrapped.statistics().expect("statistics present");
        let columns = &stats.column_statistics;

        assert_eq!(columns[0].distinct_count, Precision::Inexact(7));
        // Column 1 has no override and the inner MemTable does not report
        // stats — so the slot stays at Absent.
        assert_eq!(columns[1].distinct_count, Precision::Absent);
    }

    #[test]
    fn out_of_range_override_is_ignored() {
        let wrapped = SamkhyaTableProvider::new(tiny_mem_table())
            .with_column_stats(99, ColumnStats::new().with_distinct_count(123));

        // No panic, statistics still produced and still sized to the schema.
        let stats = wrapped.statistics().expect("statistics present");
        assert_eq!(stats.column_statistics.len(), 2);
    }

    /// WAVE5-RC2: when the inner provider's native row count exceeds the
    /// samkhya override, the published value is the inner provider's
    /// estimate, not the (smaller) samkhya value. Prevents the planner
    /// from picking a smaller hash-join build side than baseline would.
    ///
    /// Uses `MockProvider`, which reports a known `Precision::Inexact(5)`
    /// for num_rows, since `MemTable`'s default stats path leaves num_rows
    /// as `Precision::Absent` (which would trip the fallback branch, not
    /// the monotone-cap branch).
    #[test]
    fn statistics_row_count_caps_at_max_of_samkhya_and_native() {
        let schema = Arc::new(Schema::new(vec![int64_field("a")]));
        let inner: Arc<dyn TableProvider> = Arc::new(MockProvider {
            schema: Arc::clone(&schema),
            native_rows: 5,
        });

        let wrapped = SamkhyaTableProvider::new(inner)
            .with_column_stats(0, ColumnStats::new().with_row_count(3));
        let stats = wrapped.statistics().expect("statistics present");

        assert_eq!(
            stats.num_rows,
            Precision::Inexact(5),
            "monotone cap must publish max(samkhya=3, native=5)=5, not the smaller samkhya estimate"
        );
    }

    /// WAVE5-RC2: symmetric column-level guard. When samkhya's NDV
    /// override is smaller than the inner provider's native NDV, publish
    /// the native (larger) value to keep hash-join build sides on the
    /// safe side.
    #[test]
    fn statistics_distinct_count_caps_at_max_of_samkhya_and_native() {
        // Native NDV of 1000 merged with a smaller samkhya NDV of 50:
        // the larger, native value must survive.
        let native = ndv_only(1000);
        let samkhya = ndv_only(50);

        let merged = merge_column_stats(native, samkhya);

        assert_eq!(
            merged.distinct_count,
            Precision::Inexact(1000),
            "merge must publish max(samkhya, native) distinct_count"
        );
    }
}