use std::any::Any;
use std::fmt;
use std::sync::Arc;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{Result, Statistics};
use datafusion::execution::TaskContext;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
};
/// Physical-plan node that wraps a single input plan together with a
/// caller-supplied [`Statistics`] value.
///
/// The wrapped plan and the statistics are stored side by side; both are
/// exposed read-only through the inherent accessors on this type.
#[derive(Clone, Debug)]
pub struct SamkhyaStatsExec {
    /// The wrapped child plan.
    input: Arc<dyn ExecutionPlan>,
    /// The statistics supplied at construction time.
    stats: Statistics,
}
impl SamkhyaStatsExec {
pub fn new(input: Arc<dyn ExecutionPlan>, stats: Statistics) -> Self {
Self { input, stats }
}
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
pub fn override_statistics(&self) -> &Statistics {
&self.stats
}
}
impl DisplayAs for SamkhyaStatsExec {
    /// Writes a one-line description of this node: its name followed by the
    /// `Debug` rendering of the stored row-count statistic.
    ///
    /// Both display formats handled here produce the same text.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        let row_count = &self.stats.num_rows;
        match t {
            DisplayFormatType::Verbose | DisplayFormatType::Default => {
                write!(f, "SamkhyaStatsExec: num_rows={:?}", row_count)
            }
        }
    }
}
impl ExecutionPlan for SamkhyaStatsExec {
fn name(&self) -> &str {
"SamkhyaStatsExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.input.schema()
}
fn properties(&self) -> &PlanProperties {
self.input.properties()
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![true]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let new_input = children
.into_iter()
.next()
.expect("SamkhyaStatsExec has exactly one child");
Ok(Arc::new(SamkhyaStatsExec::new(
new_input,
self.stats.clone(),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
self.input.execute(partition, context)
}
fn statistics(&self) -> Result<Statistics> {
Ok(self.stats.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::Int64Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::stats::Precision;
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::execution::context::SessionContext;

    /// Builds a scan over an in-memory table holding one non-nullable `Int64`
    /// column `a` with the values 1, 2 and 3.
    async fn tiny_input_exec() -> Arc<dyn ExecutionPlan> {
        let column_a = Field::new("a", DataType::Int64, false);
        let schema = Arc::new(Schema::new(vec![column_a]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let partitions = vec![vec![batch]];
        let mem = Arc::new(MemTable::try_new(Arc::clone(&schema), partitions).unwrap());

        let state = SessionContext::new().state();
        let session: &dyn datafusion::catalog::Session = &state;
        mem.scan(session, None, &[], None).await.unwrap()
    }

    /// Wraps `inner` in a `SamkhyaStatsExec` whose statistics are all unknown
    /// except for an inexact row count of `rows`.
    fn wrap_with_inexact_rows(
        inner: Arc<dyn ExecutionPlan>,
        rows: usize,
    ) -> Arc<dyn ExecutionPlan> {
        let mut overridden = Statistics::new_unknown(inner.schema().as_ref());
        overridden.num_rows = Precision::Inexact(rows);
        Arc::new(SamkhyaStatsExec::new(inner, overridden))
    }

    /// The wrapper must report the row count it was constructed with.
    #[tokio::test(flavor = "multi_thread")]
    async fn wrapper_reports_override_stats() {
        let wrapped = wrap_with_inexact_rows(tiny_input_exec().await, 42);

        let reported = wrapped.statistics().expect("stats present");
        assert_eq!(reported.num_rows, Precision::Inexact(42));
    }

    /// Replacing the child must keep the row count the wrapper was built with.
    #[tokio::test(flavor = "multi_thread")]
    async fn with_new_children_preserves_override() {
        let inner = tiny_input_exec().await;
        let wrapped = wrap_with_inexact_rows(Arc::clone(&inner), 7);

        let rebuilt = wrapped.with_new_children(vec![inner]).expect("rebuild");

        let reported = rebuilt.statistics().unwrap();
        assert_eq!(reported.num_rows, Precision::Inexact(7));
    }
}