#![allow(dead_code, unused_imports, clippy::all)]
use std::sync::Arc;
use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::{Expr, LogicalPlanBuilder, col, lit};
use uni_plugin::traits::storage::Storage;
use uni_plugin_builtin::optimizer::{PushdownAwareTable, PushdownNegotiationRule};
use uni_plugin_builtin::storage::MemoryStorage;
use uni_plugin_builtin::storage_table_provider::{StorageFilterPushdown, StorageTableProvider};
/// Builds a [`SessionContext`] whose state is configured with default
/// features and an optimizer rule list containing only
/// [`PushdownNegotiationRule`].
fn pushdown_only_ctx() -> SessionContext {
    let with_defaults = SessionStateBuilder::new().with_default_features();
    // The rule list passed here holds exactly one entry: the pushdown rule.
    let pushdown_only = with_defaults.with_optimizer_rules(vec![Arc::new(PushdownNegotiationRule)]);
    SessionContext::new_with_state(pushdown_only.build())
}
/// Returns the schema shared by the fixtures: a single non-nullable
/// `Int64` column named `x`.
fn fixture_schema() -> Arc<Schema> {
    let x_column = Field::new("x", DataType::Int64, false);
    let schema = Schema::new(vec![x_column]);
    Arc::new(schema)
}
/// Creates a fresh [`MemoryStorage`] and writes one batch holding the values
/// `1..=8` in column `x` under the table name `mem_table`.
///
/// # Panics
///
/// Panics if the fixture batch cannot be constructed or the write fails.
async fn seed_storage() -> Arc<dyn Storage> {
    // Build the fixture rows first; they do not depend on the storage handle.
    let values = Int64Array::from(vec![1_i64, 2, 3, 4, 5, 6, 7, 8]);
    let rows = RecordBatch::try_new(fixture_schema(), vec![Arc::new(values)])
        .expect("fixture batch");

    let backend: Arc<dyn Storage> = Arc::new(MemoryStorage::new());
    let write_outcome = backend.write_batch("mem_table", &rows).await;
    write_outcome.expect("seed write");
    backend
}
/// Seeds a storage backend and registers it with `ctx` as table `mem_table`.
///
/// The [`StorageTableProvider`] is wrapped in a [`PushdownAwareTable`]
/// constructed with [`StorageFilterPushdown`] before registration.
///
/// # Panics
///
/// Panics if seeding the storage or registering the table fails.
async fn register_memtable(ctx: &SessionContext) {
    let backend = seed_storage().await;
    let base_table = Arc::new(StorageTableProvider::new(
        backend,
        "mem_table".to_owned(),
        fixture_schema(),
    ));
    let pushdown_table =
        PushdownAwareTable::with_filter(base_table, Arc::new(StorageFilterPushdown));
    ctx.register_table("mem_table", Arc::new(pushdown_table))
        .expect("register_table");
}
#[tokio::test(flavor = "multi_thread")]
async fn match_mem_table_returns_filtered_rows() {
let ctx = pushdown_only_ctx();
register_memtable(&ctx).await;
let df = ctx
.sql("SELECT x FROM mem_table WHERE x > 5 ORDER BY x")
.await
.expect("sql");
let batches = df.collect().await.expect("collect");
let rendered = pretty_format_batches(&batches).expect("format").to_string();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(
total, 3,
"x > 5 must yield 3 rows (6, 7, 8); got {total}\nbatches:\n{rendered}"
);
}
/// Runs `EXPLAIN` for `WHERE x > 5` against `mem_table` and checks the
/// rendered output: it must mention the scan and must not contain a
/// `Filter:` node.
#[tokio::test(flavor = "multi_thread")]
async fn explain_elides_filter_for_encodable_predicate() {
    let ctx = pushdown_only_ctx();
    register_memtable(&ctx).await;

    let explain_batches = ctx
        .sql("EXPLAIN SELECT x FROM mem_table WHERE x > 5")
        .await
        .expect("sql")
        .collect()
        .await
        .expect("collect");
    let rendered = pretty_format_batches(&explain_batches)
        .expect("format")
        .to_string();

    // Either scan spelling counts as a reference to the scan.
    let mentions_scan = ["StorageScanExec", "TableScan"]
        .iter()
        .any(|marker| rendered.contains(*marker));
    assert!(
        mentions_scan,
        "EXPLAIN output should reference the scan; got:\n{rendered}"
    );

    let has_filter_node = rendered.contains("Filter:");
    assert!(
        !has_filter_node,
        "EXPLAIN must NOT contain a `Filter:` node above the scan when \
         the predicate is encodable; pushdown elision failed. Plan:\n{rendered}"
    );
}
/// Negative guard: builds a logical plan that filters `mem_table` with
/// `x IS DISTINCT FROM 5`, optimizes it, and asserts that a
/// `LogicalPlan::Filter` node is still present in the optimized plan.
///
/// The plan is assembled with `LogicalPlanBuilder` rather than SQL.
// NOTE(review): presumably `IsDistinctFrom` is an operator that
// `StorageFilterPushdown` cannot encode — confirm against its implementation.
#[tokio::test(flavor = "multi_thread")]
async fn explain_keeps_filter_for_inexpressible_predicate() {
    use datafusion::common::tree_node::TreeNode;
    use datafusion::logical_expr::LogicalPlan;

    let ctx = pushdown_only_ctx();
    register_memtable(&ctx).await;
    let scan = ctx.table("mem_table").await.expect("mem_table");

    let unencodable_predicate: Expr = Expr::BinaryExpr(datafusion::logical_expr::BinaryExpr::new(
        Box::new(col("x")),
        datafusion::logical_expr::Operator::IsDistinctFrom,
        Box::new(lit(5_i64)),
    ));
    let plan = LogicalPlanBuilder::from(scan.into_optimized_plan().expect("optimize scan"))
        .filter(unencodable_predicate)
        .expect("filter")
        .build()
        .expect("build");

    let state = ctx.state();
    let optimized = state.optimize(&plan).expect("optimize");

    // Surface a traversal failure instead of discarding it: a swallowed error
    // would otherwise be misreported below as "no Filter in the plan".
    let has_filter = optimized
        .exists(|node| Ok(matches!(node, LogicalPlan::Filter(_))))
        .expect("plan traversal");
    assert!(
        has_filter,
        "Filter MUST stay above the scan when the predicate is inexpressible; \
         negative-guard regression. Optimized plan:\n{optimized:?}"
    );
}