use futures::{StreamExt, TryStreamExt};
use std::any::Any;
use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::LogicalPlan;
use tokio::sync::RwLock;
use tokio::task;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::common;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
/// In-memory table: a fixed schema plus record batches grouped into partitions.
#[derive(Debug)]
pub struct MemTable {
/// Table schema; `try_new` requires every stored batch's schema to be
/// contained in it (via `Schema::contains`).
schema: SchemaRef,
/// Stored data: the outer `Vec` has one entry per partition, the inner `Vec`
/// holds that partition's batches. Wrapped in a tokio `RwLock` so that
/// `insert_into` can append while only holding `&self`.
batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
}
impl MemTable {
    /// Creates an in-memory table from `partitions`, where each inner `Vec`
    /// holds the record batches of one partition.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Plan`] if any batch's schema is not
    /// contained in `schema` (as decided by `Schema::contains`).
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        let compatible = partitions
            .iter()
            .flatten()
            .all(|batch| schema.contains(&batch.schema()));
        if !compatible {
            return Err(DataFusionError::Plan(
                "Mismatch between schema and batches".to_string(),
            ));
        }
        Ok(Self {
            schema,
            batches: Arc::new(RwLock::new(partitions)),
        })
    }

    /// Reads every row of `t` into a new `MemTable`.
    ///
    /// Each partition of `t`'s (unprojected, unfiltered) scan is collected on
    /// its own tokio task. If `output_partitions` is `Some(n)`, the collected
    /// batches are then redistributed round-robin into `n` partitions;
    /// otherwise the source's partitioning is kept as-is.
    ///
    /// # Errors
    ///
    /// Propagates scan/execution errors from `t`. A collection task that fails
    /// to join is reported as [`DataFusionError::External`]. Errors from
    /// building or running the repartitioning plan are also propagated.
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &SessionState,
    ) -> Result<Self> {
        let schema = t.schema();
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        // Collect every source partition concurrently, one spawned task each.
        // NOTE(review): `AbortOnDropSingle` presumably aborts its task when
        // dropped (per its name) — confirm in `physical_plan::common`.
        let tasks = (0..partition_count)
            .map(|part_i| {
                let task_ctx = state.task_ctx();
                let exec = Arc::clone(&exec);
                let handle = tokio::spawn(async move {
                    let stream = exec.execute(part_i, task_ctx)?;
                    common::collect(stream).await
                });
                AbortOnDropSingle::new(handle)
            })
            .collect::<Vec<_>>();

        let mut data: Vec<Vec<RecordBatch>> = Vec::with_capacity(partition_count);
        for result in futures::future::join_all(tasks).await {
            // First `?`: the task failed to join; second `?`: the partition
            // itself failed to execute or collect.
            data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??);
        }

        match output_partitions {
            // Keep the source partitioning; no intermediate plan is needed.
            None => MemTable::try_new(schema, data),
            Some(num_partitions) => {
                // Only build the in-memory input plan when it is actually
                // repartitioned.
                let input = MemoryExec::try_new(&data, Arc::clone(&schema), None)?;
                let exec = RepartitionExec::try_new(
                    Arc::new(input),
                    Partitioning::RoundRobinBatch(num_partitions),
                )?;

                // Drain each output partition in turn.
                let mut repartitioned = vec![];
                for i in 0..exec.output_partitioning().partition_count() {
                    let mut stream = exec.execute(i, state.task_ctx())?;
                    let mut batches = vec![];
                    while let Some(batch) = stream.next().await {
                        batches.push(batch?);
                    }
                    repartitioned.push(batches);
                }
                MemTable::try_new(schema, repartitioned)
            }
        }
    }
}
#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Returns a [`MemoryExec`] built from the table's current partitions
    /// (read under the read lock), optionally projected.
    ///
    /// `_filters` and `_limit` are not used by this implementation.
    async fn scan(
        &self,
        _state: &SessionState,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let batches = &self.batches.read().await;
        Ok(Arc::new(MemoryExec::try_new(
            batches,
            self.schema(),
            projection.cloned(),
        )?))
    }

    /// Executes `input` and appends its output to this table.
    ///
    /// The physical plan is adapted to the table's partitioning. It is used
    /// as-is when the partition counts already match or the table has no
    /// partitions yet. It is coalesced when the table has exactly one
    /// partition, and repartitioned round-robin otherwise. Every output
    /// partition is collected on its own tokio task before the write lock is
    /// taken.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Plan`] if the plan's schema differs from the
    /// table's. Returns [`DataFusionError::Execution`] if a collection task
    /// fails to join. Planning and execution errors are propagated.
    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
        let plan = state.create_physical_plan(input).await?;
        if !plan.schema().eq(&self.schema) {
            return Err(DataFusionError::Plan(
                "Inserting query must have the same schema with the table.".to_string(),
            ));
        }

        // Shape the plan's output partitioning to match the table's.
        let plan_partition_count = plan.output_partitioning().partition_count();
        let table_partition_count = self.batches.read().await.len();
        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
            == table_partition_count
            || table_partition_count == 0
        {
            plan
        } else if table_partition_count == 1 {
            Arc::new(CoalescePartitionsExec::new(plan))
        } else {
            Arc::new(RepartitionExec::try_new(
                plan,
                Partitioning::RoundRobinBatch(table_partition_count),
            )?)
        };

        // Collect every output partition concurrently.
        let task_ctx = state.task_ctx();
        let mut tasks = vec![];
        for idx in 0..plan.output_partitioning().partition_count() {
            let stream = plan.execute(idx, task_ctx.clone())?;
            let handle = task::spawn(async move {
                stream.try_collect().await.map_err(DataFusionError::from)
            });
            tasks.push(AbortOnDropSingle::new(handle));
        }
        let results = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|result| {
                result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
            })
            .collect::<Result<Vec<Vec<RecordBatch>>>>()?;

        let mut all_batches = self.batches.write().await;
        if all_batches.is_empty() {
            *all_batches = results;
        } else {
            // The partition count used above was read under a read lock that
            // has since been released. A concurrent insert may have populated
            // a previously empty table with a different number of partitions.
            // Fold result partitions round-robin rather than `zip`ping, which
            // would silently drop any surplus result partitions. When the
            // counts match, this is the same one-to-one append.
            let table_partitions = all_batches.len();
            for (idx, batches) in results.into_iter().enumerate() {
                all_batches[idx % table_partitions].extend(batches);
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::provider_as_source;
    use crate::from_slice::FromSlice;
    use crate::prelude::SessionContext;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::StreamExt;
    use std::collections::HashMap;

    /// A projected scan returns only the requested columns, in the requested
    /// order.
    #[tokio::test]
    async fn test_with_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
                Arc::new(Int32Array::from(vec![None, None, Some(9)])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let exec = provider
            .scan(&session_ctx.state(), Some(&vec![2, 1]), &[], None)
            .await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch2 = it.next().await.unwrap()?;
        assert_eq!(2, batch2.schema().fields().len());
        assert_eq!("c", batch2.schema().field(0).name());
        assert_eq!("b", batch2.schema().field(1).name());
        assert_eq!(2, batch2.num_columns());
        Ok(())
    }

    /// An unprojected scan returns every column.
    #[tokio::test]
    async fn test_without_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());
        Ok(())
    }

    /// A projection referencing a non-existent column index fails at scan
    /// time with an Arrow schema error.
    #[tokio::test]
    async fn test_invalid_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let projection: Vec<usize> = vec![0, 4];
        match provider
            .scan(&session_ctx.state(), Some(&projection), &[], None)
            .await
        {
            Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
                assert_eq!(
                    "\"project index 4 out of bounds, max field 3\"",
                    format!("{e:?}")
                )
            }
            res => panic!("Scan should have failed on invalid projection, got {res:?}"),
        };
        Ok(())
    }

    /// A batch whose column type differs from the table schema is rejected.
    #[test]
    fn test_schema_validation_incompatible_column() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => {
                assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}"))
            }
            _ => panic!("MemTable::try_new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// A batch with a different number of columns than the table schema is
    /// rejected.
    #[test]
    fn test_schema_validation_different_column_count() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([7, 5, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => {
                assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}"))
            }
            _ => panic!("MemTable::try_new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// Batches with differing (but mergeable) schemas are accepted when the
    /// table schema is their merge.
    #[tokio::test]
    async fn test_merged_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema1 = Schema::new_with_metadata(
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ],
            metadata,
        );
        let schema2 = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
            ],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(Int32Array::from_slice([1, 2, 3])),
                Arc::new(Int32Array::from_slice([4, 5, 6])),
                Arc::new(Int32Array::from_slice([7, 8, 9])),
            ],
        )?;
        let provider =
            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let output = it.next().await.unwrap()?;
        assert_eq!(3, output.schema().fields().len());
        assert_eq!(3, output.num_columns());
        Ok(())
    }

    /// Builds a logical scan named `source` over a fresh `MemTable` holding
    /// `data`.
    fn create_mem_table_scan(
        schema: SchemaRef,
        data: Vec<Vec<RecordBatch>>,
    ) -> Result<Arc<LogicalPlan>> {
        let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
        Ok(Arc::new(
            LogicalPlanBuilder::scan("source", provider, None)?.build()?,
        ))
    }

    /// Returns a session, a one-column (`a: Int32`) schema, and a three-row
    /// batch in that schema.
    fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
        let session_ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
        )?;
        Ok((session_ctx, schema, batch))
    }

    /// Inserting into a single-partition table keeps one partition, whatever
    /// the partitioning of the inserted plan.
    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()]],
        )?);
        let single_partition_table_scan =
            create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        assert_eq!(initial_table.batches.read().await.len(), 1);
        Ok(())
    }

    /// Inserting into a two-partition table spreads the new batches across
    /// both partitions.
    #[tokio::test]
    async fn test_insert_into_multiple_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch.clone()]],
        )?);
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);

        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
        Ok(())
    }

    /// The first insert into an empty table adopts the plan's partitioning;
    /// later inserts are shaped to match it.
    #[tokio::test]
    async fn test_insert_into_empty_table() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.len(), 1);
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        assert_eq!(initial_table.batches.read().await.len(), 1);
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        Ok(())
    }
}