use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::array::{
Array, ArrayRef, Int32Array, RecordBatch, StringArray, UInt64Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::{Expr, Operator, TableProviderFilterPushDown};
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::{ExecutionPlan, Statistics};
use datafusion::scalar::ScalarValue;
use datafusion_common::DataFusionError;
use datafusion_index_provider::physical_plan::exec::fetch::RecordFetchExec;
use datafusion_index_provider::physical_plan::{create_index_schema, Index};
use datafusion_index_provider::{IndexedTableProvider, RecordFetcher, UnionMode};
/// In-memory secondary index over the `age` column.
///
/// Each key is an age value; the associated vector holds the `id` values that
/// were paired with that age when the index was built.
#[derive(Debug)]
struct AgeIndex {
    // Ordered map so range predicates can be answered with `BTreeMap::range`.
    index: BTreeMap<i32, Vec<i32>>,
}
impl AgeIndex {
    /// Builds the index from two parallel arrays: `ages[i]` is the key under
    /// which `ids[i]` is recorded.
    ///
    /// Iteration is driven by `ages.len()`, so `ids` must be at least as long.
    fn new(ages: &Int32Array, ids: &Int32Array) -> Self {
        let mut index: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
        for i in 0..ages.len() {
            index.entry(ages.value(i)).or_default().push(ids.value(i));
        }
        Self { index }
    }

    /// Returns the ids selected by `filters`, in ascending id order, truncated
    /// to `limit` entries when a limit is given.
    ///
    /// Only filters of the shape `age <op> <Int32 literal>` with `<op>` in
    /// `=`, `>`, `>=`, `<`, `<=` are evaluated; any other expression is
    /// ignored. The ids contributed by each recognised filter are accumulated
    /// into a single set, so an id appears at most once in the result.
    fn matching_ids(&self, filters: &[Expr], limit: Option<usize>) -> Vec<u64> {
        use std::ops::Bound;

        let mut ids: BTreeSet<i32> = BTreeSet::new();
        for filter in filters {
            if let Expr::BinaryExpr(be) = filter {
                if let (Expr::Column(c), Expr::Literal(ScalarValue::Int32(Some(v)), _)) =
                    (be.left.as_ref(), be.right.as_ref())
                {
                    if c.name != "age" {
                        continue;
                    }
                    match be.op {
                        Operator::Eq => {
                            if let Some(list) = self.index.get(v) {
                                ids.extend(list);
                            }
                        }
                        Operator::Gt => {
                            // An exclusive lower bound avoids computing `v + 1`,
                            // which overflows when the literal is `i32::MAX`.
                            let above: (Bound<i32>, Bound<i32>) =
                                (Bound::Excluded(*v), Bound::Unbounded);
                            ids.extend(self.index.range(above).flat_map(|(_, l)| l))
                        }
                        Operator::GtEq => ids.extend(self.index.range(v..).flat_map(|(_, l)| l)),
                        Operator::Lt => ids.extend(self.index.range(..v).flat_map(|(_, l)| l)),
                        Operator::LtEq => ids.extend(self.index.range(..=v).flat_map(|(_, l)| l)),
                        _ => {}
                    }
                }
            }
        }
        // NOTE(review): the cast assumes ids are non-negative; a negative `i32`
        // would wrap to a very large `u64`.
        let mut result: Vec<u64> = ids.into_iter().map(|id| id as u64).collect();
        if let Some(l) = limit {
            result.truncate(l);
        }
        result
    }
}
/// Exposes [`AgeIndex`] through the `Index` trait.
impl Index for AgeIndex {
    /// Gives access to the concrete type behind a trait object.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of this index.
    fn name(&self) -> &str {
        "age_index"
    }

    /// Schema of the batches produced by [`Self::scan`]: one non-nullable
    /// `UInt64` column called `id`.
    fn index_schema(&self) -> SchemaRef {
        let id_field = Field::new("id", DataType::UInt64, false);
        create_index_schema([id_field])
    }

    /// Name of the table this index belongs to.
    fn table_name(&self) -> &str {
        "employees"
    }

    /// Name of the indexed column.
    fn column_name(&self) -> &str {
        "age"
    }

    /// Evaluates `filters` against the index and streams the matching ids.
    ///
    /// The stream carries a single batch when at least one id matches and no
    /// batches at all otherwise.
    fn scan(
        &self,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        let matched = self.matching_ids(filters, limit);
        let schema = self.index_schema();

        let mut batches = Vec::new();
        if !matched.is_empty() {
            let id_column = Arc::new(UInt64Array::from(matched)) as ArrayRef;
            batches.push(RecordBatch::try_new(schema.clone(), vec![id_column])?);
        }

        let stream = MemoryStream::try_new(batches, schema, None)?;
        Ok(Box::pin(stream))
    }

    /// No statistics are tracked; everything is reported as unknown.
    fn statistics(&self) -> Statistics {
        let schema = self.index_schema();
        Statistics::new_unknown(schema.as_ref())
    }
}
/// Record fetcher backed by a single in-memory batch.
struct InMemoryFetcher {
    // The full table contents; `fetch` selects rows out of this batch.
    batch: RecordBatch,
}
/// Prints only the type name; the contents of the wrapped batch are omitted.
impl fmt::Debug for InMemoryFetcher {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("InMemoryFetcher")
    }
}
#[async_trait]
impl RecordFetcher for InMemoryFetcher {
fn schema(&self) -> SchemaRef {
self.batch.schema()
}
async fn fetch(&self, index_batch: RecordBatch) -> Result<RecordBatch> {
let ids = index_batch
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.expect("expected UInt64Array for primary key column");
let indices = Int32Array::from_iter_values(ids.iter().flatten().map(|id| (id - 1) as i32));
let columns: Result<Vec<ArrayRef>> = self
.batch
.columns()
.iter()
.map(|col| {
Ok(Arc::new(datafusion::arrow::compute::take(
col.as_ref(),
&indices,
None,
)?) as ArrayRef)
})
.collect();
Ok(RecordBatch::try_new(self.batch.schema(), columns?)?)
}
}
/// Table provider for the example `employees` table, backed by an age index
/// and an in-memory record fetcher.
#[derive(Debug)]
struct EmployeeTable {
    // Schema reported to DataFusion and handed to the fetch plan.
    schema: SchemaRef,
    // The only index this table exposes through `IndexedTableProvider::indexes`.
    age_index: Arc<AgeIndex>,
    // Passed to the fetch plan built in `scan`.
    fetcher: Arc<InMemoryFetcher>,
}
#[async_trait]
impl TableProvider for EmployeeTable {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
_projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let (indexed, _remaining) = self.analyze_and_optimize_filters(filters)?;
if indexed.is_empty() {
unimplemented!("full-table scan fallback not shown in this example");
}
Ok(Arc::new(RecordFetchExec::try_new(
indexed,
limit,
self.fetcher.clone(),
self.schema.clone(),
UnionMode::Parallel,
)?))
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
self.supports_filters_index_pushdown(filters)
}
}
#[async_trait]
impl IndexedTableProvider for EmployeeTable {
fn indexes(&self) -> Result<Vec<Arc<dyn Index + 'static>>, DataFusionError> {
Ok(vec![self.age_index.clone()])
}
}
#[tokio::main]
async fn main() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
Field::new("age", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
Arc::new(StringArray::from(vec![
"Alice", "Bob", "Charlie", "David", "Eve",
])),
Arc::new(Int32Array::from(vec![25, 30, 35, 28, 32])),
],
)?;
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let ages = batch
.column(2)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let provider = EmployeeTable {
schema: schema.clone(),
age_index: Arc::new(AgeIndex::new(ages, ids)),
fetcher: Arc::new(InMemoryFetcher {
batch: batch.clone(),
}),
};
let ctx = SessionContext::new();
ctx.register_table("employees", Arc::new(provider))?;
println!("=== Employees older than 29 (index-accelerated) ===");
let df = ctx.sql("SELECT * FROM employees WHERE age > 29").await?;
df.show().await?;
println!("\n=== Employee with age = 25 ===");
let df = ctx.sql("SELECT * FROM employees WHERE age = 25").await?;
df.show().await?;
Ok(())
}