use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::catalog::Session;
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, Constraints, Result, Statistics};
use datafusion::datasource::{TableProvider, TableType};
use datafusion::logical_expr::dml::InsertOp;
use datafusion::logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use samkhya_core::stats::ColumnStats;
use crate::physical_plan::SamkhyaStatsExec;
use crate::stats_provider::to_datafusion_column_statistics;
/// A [`TableProvider`] wrapper that layers per-column statistics overrides on
/// top of the statistics reported by the provider it wraps.
#[derive(Debug)]
pub struct SamkhyaTableProvider {
// Wrapped provider; schema, constraints, scans and inserts are forwarded to it.
inner: Arc<dyn TableProvider>,
// Statistics overrides keyed by column index into the inner provider's schema.
overrides: HashMap<usize, ColumnStats>,
// Incremented on every `statistics()` call; readable through `stats_call_count`.
stats_calls: AtomicUsize,
}
impl SamkhyaTableProvider {
    /// Wraps `inner` with no statistics overrides and a zeroed call counter.
    #[must_use]
    pub fn new(inner: Arc<dyn TableProvider>) -> Self {
        Self {
            inner,
            overrides: HashMap::new(),
            stats_calls: AtomicUsize::new(0),
        }
    }

    /// Registers `stats` as the override for the column at `col_idx` and
    /// returns the updated provider.
    ///
    /// A later call with the same `col_idx` replaces the earlier override.
    /// This is a consuming builder: the returned value is the only handle to
    /// the provider, hence `#[must_use]`.
    #[must_use]
    pub fn with_column_stats(mut self, col_idx: usize, stats: ColumnStats) -> Self {
        self.overrides.insert(col_idx, stats);
        self
    }

    /// Number of times `statistics()` has been invoked on this provider.
    pub fn stats_call_count(&self) -> usize {
        self.stats_calls.load(Ordering::SeqCst)
    }

    /// Read-only view of the registered overrides, keyed by column index.
    pub fn overrides(&self) -> &HashMap<usize, ColumnStats> {
        &self.overrides
    }
}
/// Forwards everything except statistics to the wrapped provider;
/// `statistics` and `scan` publish the override-merged statistics instead of
/// the native ones.
#[async_trait]
impl TableProvider for SamkhyaTableProvider {
    /// Returns `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the wrapped provider.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Constraints of the wrapped provider.
    fn constraints(&self) -> Option<&Constraints> {
        self.inner.constraints()
    }

    /// Table type of the wrapped provider.
    fn table_type(&self) -> TableType {
        self.inner.table_type()
    }

    /// Table definition of the wrapped provider, if it has one.
    fn get_table_definition(&self) -> Option<&str> {
        self.inner.get_table_definition()
    }

    /// Logical plan of the wrapped provider, if it has one.
    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
        self.inner.get_logical_plan()
    }

    /// Default expression for `column` as reported by the wrapped provider.
    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.inner.get_column_default(column)
    }

    /// Scans the wrapped provider and wraps the resulting plan in a
    /// [`SamkhyaStatsExec`] carrying the override-merged statistics, projected
    /// onto `projection`.
    ///
    /// Calls `statistics()` internally, so each scan also bumps the counter
    /// reported by `stats_call_count`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped provider's `scan`.
    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let inner_plan = self.inner.scan(state, projection, filters, limit).await?;
        let table_stats = self
            .statistics()
            .unwrap_or_else(|| Statistics::new_unknown(self.inner.schema().as_ref()));
        // The merged statistics describe the whole table. Once filters or a
        // limit have been handed to the inner scan, the plan may emit fewer
        // rows, so nothing may be advertised as `Exact` any more; consumers
        // that trust exact statistics would otherwise be misled.
        let table_stats = if filters.is_empty() && limit.is_none() {
            table_stats
        } else {
            table_stats.to_inexact()
        };
        let output_stats = table_stats.project(projection);
        Ok(Arc::new(SamkhyaStatsExec::new(inner_plan, output_stats)))
    }

    /// Filter push-down capabilities of the wrapped provider.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        self.inner.supports_filters_pushdown(filters)
    }

    /// Native statistics of the wrapped provider with the registered
    /// overrides merged in. Always returns `Some`.
    ///
    /// * Missing native statistics, or native column statistics whose length
    ///   disagrees with the schema, are replaced by "unknown" placeholders.
    /// * Overrides whose column index lies outside the schema are skipped.
    /// * If any override carries a row count, `num_rows` becomes
    ///   `Inexact(max(largest override row count, native row count))` and
    ///   `total_byte_size` is demoted to inexact.
    fn statistics(&self) -> Option<Statistics> {
        self.stats_calls.fetch_add(1, Ordering::SeqCst);

        let schema = self.inner.schema();
        let mut merged = self
            .inner
            .statistics()
            .unwrap_or_else(|| Statistics::new_unknown(schema.as_ref()));
        // Guarantee one entry per schema field so override indices line up.
        if merged.column_statistics.len() != schema.fields().len() {
            merged.column_statistics = Statistics::unknown_column(schema.as_ref());
        }

        for (col_idx, override_stats) in &self.overrides {
            // `get_mut` doubles as the out-of-range guard.
            let Some(slot) = merged.column_statistics.get_mut(*col_idx) else {
                continue;
            };
            // Move the native stats out instead of cloning them.
            let native = std::mem::take(slot);
            *slot = merge_column_stats(native, to_datafusion_column_statistics(override_stats));
        }

        let override_row_count = self.overrides.values().filter_map(|s| s.row_count).max();
        if let Some(rc) = override_row_count {
            // NOTE(review): `as` silently truncates if `row_count` is wider
            // than `usize` on the target; confirm the source type.
            let rc = rc as usize;
            // Never publish fewer rows than the native provider reports.
            let floor = match merged.num_rows {
                Precision::Exact(n) | Precision::Inexact(n) => rc.max(n),
                Precision::Absent => rc,
            };
            merged.num_rows = Precision::Inexact(floor);
            // The byte size was derived from the native row count, so it can
            // no longer be claimed as exact.
            merged.total_byte_size = merged.total_byte_size.to_inexact();
        }
        Some(merged)
    }

    /// Forwards the insert to the wrapped provider.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped provider's `insert_into`.
    async fn insert_into(
        &self,
        state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.inner.insert_into(state, input, insert_op).await
    }
}
/// Combines native column statistics (`base`) with override statistics (`ovr`).
///
/// For `null_count`, `max_value`, `min_value` and `sum_value` the override wins
/// whenever it is present. `distinct_count` is the larger of the two values,
/// always reported as inexact.
fn merge_column_stats(base: ColumnStatistics, ovr: ColumnStatistics) -> ColumnStatistics {
    let mut merged = base;
    merged.null_count = pick(merged.null_count, ovr.null_count);
    merged.max_value = pick(merged.max_value, ovr.max_value);
    merged.min_value = pick(merged.min_value, ovr.min_value);
    merged.sum_value = pick(merged.sum_value, ovr.sum_value);
    merged.distinct_count = pick_max_usize(merged.distinct_count, ovr.distinct_count);
    merged
}
/// Returns the larger of the two counts as `Precision::Inexact`.
///
/// If only one side carries a value, that value is returned (still inexact);
/// if neither does, the result is `Precision::Absent`.
fn pick_max_usize(base: Precision<usize>, ovr: Precision<usize>) -> Precision<usize> {
    /// Extracts the carried value, ignoring whether it was exact.
    fn known(p: Precision<usize>) -> Option<usize> {
        match p {
            Precision::Absent => None,
            Precision::Exact(v) | Precision::Inexact(v) => Some(v),
        }
    }

    // `Option` orders `None` below every `Some`, so `max` yields the larger
    // value when both are present, the lone value when one is, else `None`.
    known(base)
        .max(known(ovr))
        .map_or(Precision::Absent, Precision::Inexact)
}
/// Prefers the override: returns `ovr` unchanged unless it is
/// `Precision::Absent`, in which case `base` is returned.
fn pick<T>(base: Precision<T>, ovr: Precision<T>) -> Precision<T>
where
    T: std::fmt::Debug + Clone + PartialEq + Eq + PartialOrd,
{
    if matches!(ovr, Precision::Absent) {
        base
    } else {
        ovr
    }
}
// Unit tests for the override-merging behaviour of `SamkhyaTableProvider`
// and for the `merge_column_stats` helper.
#[cfg(test)]
mod tests {
use super::*;
use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
// Builds an in-memory table with two non-nullable Int64 columns (`a`, `b`)
// holding five rows in a single batch.
fn tiny_mem_table() -> Arc<MemTable> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Int64, false),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
Arc::new(Int64Array::from(vec![10, 20, 30, 40, 50])),
],
)
.unwrap();
Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap())
}
// The builder stores the override under the given column index.
#[test]
fn builder_records_overrides() {
let inner = tiny_mem_table();
let wrapped = SamkhyaTableProvider::new(inner)
.with_column_stats(0, ColumnStats::new().with_row_count(999));
assert_eq!(wrapped.overrides().len(), 1);
assert_eq!(wrapped.overrides()[&0].row_count, Some(999));
}
// An override's row count and distinct count surface as inexact values, and
// one `statistics()` call bumps the call counter to 1.
#[test]
fn statistics_overrides_row_count() {
let inner = tiny_mem_table();
let wrapped = SamkhyaTableProvider::new(inner).with_column_stats(
0,
ColumnStats::new()
.with_row_count(999)
.with_distinct_count(42),
);
let stats = wrapped.statistics().expect("statistics present");
assert_eq!(stats.num_rows, Precision::Inexact(999));
assert_eq!(
stats.column_statistics[0].distinct_count,
Precision::Inexact(42)
);
assert_eq!(wrapped.stats_call_count(), 1);
}
// Only the overridden column (0) gains a distinct count; column 1 stays absent.
#[test]
fn statistics_falls_back_for_unoverridden_columns() {
let inner = tiny_mem_table();
let wrapped = SamkhyaTableProvider::new(inner)
.with_column_stats(0, ColumnStats::new().with_distinct_count(7));
let stats = wrapped.statistics().expect("statistics present");
assert_eq!(
stats.column_statistics[0].distinct_count,
Precision::Inexact(7)
);
assert_eq!(stats.column_statistics[1].distinct_count, Precision::Absent);
}
// An override for a column index beyond the schema must not panic or grow
// the column statistics vector.
#[test]
fn out_of_range_override_is_ignored() {
let inner = tiny_mem_table();
let wrapped = SamkhyaTableProvider::new(inner)
.with_column_stats(99, ColumnStats::new().with_distinct_count(123));
let stats = wrapped.statistics().expect("statistics present");
assert_eq!(stats.column_statistics.len(), 2);
}
// When the native provider reports more rows than the override, the larger
// native figure is published.
#[test]
fn statistics_row_count_caps_at_max_of_samkhya_and_native() {
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::Result as DfResult;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::logical_expr::Expr;
use datafusion::physical_plan::ExecutionPlan;
// Minimal provider whose only meaningful behaviour is reporting
// `native_rows` as an inexact row count.
#[derive(Debug)]
struct MockProvider {
schema: SchemaRef,
native_rows: usize,
}
#[async_trait]
impl TableProvider for MockProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
_projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
unreachable!("scan not exercised by this test")
}
fn statistics(&self) -> Option<Statistics> {
let mut s = Statistics::new_unknown(self.schema.as_ref());
s.num_rows = Precision::Inexact(self.native_rows);
Some(s)
}
}
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
let inner: Arc<dyn TableProvider> = Arc::new(MockProvider {
schema: Arc::clone(&schema),
native_rows: 5,
});
// Override claims 3 rows, native claims 5.
let wrapped = SamkhyaTableProvider::new(inner)
.with_column_stats(0, ColumnStats::new().with_row_count(3));
let stats = wrapped.statistics().expect("statistics present");
assert_eq!(
stats.num_rows,
Precision::Inexact(5),
"monotone cap must publish max(samkhya=3, native=5)=5, not the smaller samkhya estimate"
);
}
// `merge_column_stats` keeps the larger distinct count (native 1000 vs
// override 50).
#[test]
fn statistics_distinct_count_caps_at_max_of_samkhya_and_native() {
let base = ColumnStatistics {
null_count: Precision::Absent,
max_value: Precision::Absent,
min_value: Precision::Absent,
sum_value: Precision::Absent,
distinct_count: Precision::Inexact(1000),
};
let ovr = ColumnStatistics {
null_count: Precision::Absent,
max_value: Precision::Absent,
min_value: Precision::Absent,
sum_value: Precision::Absent,
distinct_count: Precision::Inexact(50),
};
let merged = merge_column_stats(base, ovr);
assert_eq!(
merged.distinct_count,
Precision::Inexact(1000),
"merge must publish max(samkhya, native) distinct_count"
);
}
}