use std::sync::Arc;
use datafusion::arrow::array::{Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::execution::context::SessionContext;
use samkhya_core::stats::ColumnStats;
use samkhya_datafusion::SamkhyaTableProvider;
/// Builds the inner provider used by these tests: an in-memory table with a
/// single partition holding one five-row batch over the non-nullable columns
/// `id: Int64` (values 1..=5) and `name: Utf8` ("a".."e").
fn build_inner_table() -> Arc<MemTable> {
    let fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let ids = Int64Array::from(vec![1, 2, 3, 4, 5]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(names)])
        .expect("record batch");

    // One partition containing the single batch.
    let partitions = vec![vec![batch]];
    let table = MemTable::try_new(schema, partitions).expect("mem table");
    Arc::new(table)
}
/// Wraps the five-row fixture table with overridden statistics for column 0
/// (row count 999, distinct count 42) and checks that:
/// - executing `SELECT *` still returns the table's real five rows;
/// - `EXPLAIN VERBOSE` plans and executes against the wrapped provider;
/// - calling `TableProvider::statistics()` through the trait object reports
///   the overridden values as `Precision::Inexact`, and increments the
///   wrapper's call counter by exactly one.
#[tokio::test(flavor = "multi_thread")]
async fn wrapper_injects_corrected_row_count() {
    let inner = build_inner_table();
    let wrapped = Arc::new(
        SamkhyaTableProvider::new(inner).with_column_stats(
            0,
            ColumnStats::new()
                .with_row_count(999)
                .with_distinct_count(42),
        ),
    );
    // Keep a concretely-typed handle so `stats_call_count()` stays reachable
    // after the provider is handed to the context as `Arc<dyn TableProvider>`.
    let wrapped_handle: Arc<SamkhyaTableProvider> = Arc::clone(&wrapped);

    let ctx = SessionContext::new();
    ctx.register_table("t", wrapped as Arc<dyn TableProvider>)
        .expect("register wrapped provider");

    // The overrides are statistics only; the rows actually produced must come
    // from the inner MemTable.
    let df = ctx
        .sql("SELECT * FROM t")
        .await
        .expect("SELECT * should plan");
    let batches = df.collect().await.expect("SELECT * should execute");
    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(
        total_rows, 5,
        "execution-time row count must equal the MemTable's actual rows"
    );

    let explain_df = ctx
        .sql("EXPLAIN VERBOSE SELECT * FROM t")
        .await
        .expect("EXPLAIN VERBOSE should plan");
    let explain_batches = explain_df
        .collect()
        .await
        .expect("EXPLAIN VERBOSE should execute");
    assert!(
        !explain_batches.is_empty(),
        "EXPLAIN VERBOSE produced no output"
    );

    // Snapshot the counter so the delta asserted below isolates the direct
    // trait call from however many calls (possibly zero) planning made above.
    let calls_before_direct = wrapped_handle.stats_call_count();

    // Go through the trait object so dispatch matches what a `dyn TableProvider`
    // consumer would see.
    let provider_trait: &dyn TableProvider = wrapped_handle.as_ref();
    let stats = provider_trait
        .statistics()
        .expect("wrapper must return Some(Statistics) via the TableProvider trait");
    assert_eq!(
        stats.num_rows,
        Precision::Inexact(999),
        "corrected row count must be 999, not the MemTable's actual count"
    );
    assert_eq!(
        stats.column_statistics[0].distinct_count,
        Precision::Inexact(42),
        "corrected distinct count must be 42"
    );

    // An exact +1 delta also proves the counter is at least 1, so no separate
    // "was it ever invoked" check is needed.
    assert_eq!(
        wrapped_handle.stats_call_count(),
        calls_before_direct + 1,
        "TableProvider::statistics() call should increment the counter"
    );
}
/// With no `with_column_stats` overrides, the wrapper still returns a
/// `Statistics` value. It has two column entries, matching the two-field
/// inner schema, and an absent row count.
#[tokio::test(flavor = "multi_thread")]
async fn wrapper_without_overrides_returns_skeleton_stats() {
    let provider = SamkhyaTableProvider::new(build_inner_table());
    let skeleton = provider.statistics().expect("statistics present");

    let column_entries = skeleton.column_statistics.len();
    assert_eq!(column_entries, 2);

    let row_count = skeleton.num_rows;
    assert_eq!(row_count, Precision::Absent);
}