use futures::StreamExt;
use log::debug;
use std::any::Any;
use std::fmt::{self, Debug};
use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_execution::TaskContext;
use tokio::sync::RwLock;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::insert::{DataSink, InsertExec};
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{common, SendableRecordBatchStream};
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
/// One table partition's batches, shared behind an async `RwLock` so that
/// concurrent scans (readers) and inserts (writers) can coexist.
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
/// An in-memory table backed by vectors of [`RecordBatch`]es, one vector
/// per partition.
#[derive(Debug)]
pub struct MemTable {
    /// Schema that every batch in `batches` must be contained in.
    schema: SchemaRef,
    /// One entry per partition; shared (`Arc`) so sinks created by
    /// `insert_into` append to the same underlying data.
    pub(crate) batches: Vec<PartitionData>,
}
impl MemTable {
    /// Create a new in-memory table from the schema and the given
    /// per-partition batches.
    ///
    /// # Errors
    /// Returns a plan error if any batch's schema is not contained in
    /// `schema` (see `Schema::contains`).
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        for batches in partitions.iter().flatten() {
            let batches_schema = batches.schema();
            if !schema.contains(&batches_schema) {
                debug!(
                    "mem table schema does not contain batches schema. \
                    Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
                );
                return Err(DataFusionError::Plan(
                    "Mismatch between schema and batches".to_string(),
                ));
            }
        }
        Ok(Self {
            schema,
            batches: partitions
                .into_iter()
                .map(|e| Arc::new(RwLock::new(e)))
                .collect::<Vec<_>>(),
        })
    }

    /// Materialize the output of table provider `t` into a new `MemTable`,
    /// optionally round-robin repartitioning it into `output_partitions`
    /// partitions.
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &SessionState,
    ) -> Result<Self> {
        let schema = t.schema();
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();
        // Collect each input partition concurrently, one tokio task apiece.
        let tasks = (0..partition_count)
            .map(|part_i| {
                let task = state.task_ctx();
                let exec = exec.clone();
                let task = tokio::spawn(async move {
                    let stream = exec.execute(part_i, task)?;
                    common::collect(stream).await
                });
                // Ensure the spawned task is aborted if this future is
                // dropped before completion.
                AbortOnDropSingle::new(task)
            })
            .collect::<Vec<_>>();
        let mut data: Vec<Vec<RecordBatch>> = Vec::with_capacity(partition_count);
        for result in futures::future::join_all(tasks).await {
            // First `?` surfaces a JoinError, the second the execution error.
            data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
        }
        if let Some(num_partitions) = output_partitions {
            // Only build the intermediate MemoryExec when we actually need
            // to repartition the collected data.
            let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;
            // Drain the repartitioned streams back into partition vectors.
            let mut output_partitions = vec![];
            for i in 0..exec.output_partitioning().partition_count() {
                let task_ctx = state.task_ctx();
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }
            return MemTable::try_new(schema, output_partitions);
        }
        MemTable::try_new(schema, data)
    }
}
#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Snapshot every partition's batches under the read lock and wrap the
    /// copies in a `MemoryExec`.
    async fn scan(
        &self,
        _state: &SessionState,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut partitions = Vec::with_capacity(self.batches.len());
        for partition in &self.batches {
            let guard = partition.read().await;
            partitions.push(guard.clone());
        }
        let exec = MemoryExec::try_new_owned_data(
            partitions,
            self.schema(),
            projection.cloned(),
        )?;
        Ok(Arc::new(exec))
    }

    /// Build an execution plan that appends the output of `input` to this
    /// table's partitions.
    async fn insert_into(
        &self,
        _state: &SessionState,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if input.schema() != self.schema {
            return Err(DataFusionError::Plan(
                "Inserting query must have the same schema with the table.".to_string(),
            ));
        }
        let sink = Arc::new(MemSink::new(self.batches.clone()));
        Ok(Arc::new(InsertExec::new(input, sink)))
    }
}
/// A sink that appends incoming record batches to the partitions of a
/// [`MemTable`], via the table's shared `PartitionData` locks.
struct MemSink {
    /// Target partitions of the table being inserted into.
    batches: Vec<PartitionData>,
}
impl Debug for MemSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show only the partition count; the batches themselves may be large.
        f.debug_struct("MemSink")
            .field("num_partitions", &self.batches.len())
            .finish()
    }
}
impl DisplayAs for MemSink {
    /// One-line description shown in plan/explain output; both display
    /// modes render identically.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "MemoryTable (partitions={})", self.batches.len())
            }
        }
    }
}
impl MemSink {
fn new(batches: Vec<PartitionData>) -> Self {
Self { batches }
}
}
#[async_trait]
impl DataSink for MemSink {
async fn write_all(
&self,
mut data: SendableRecordBatchStream,
_context: &Arc<TaskContext>,
) -> Result<u64> {
let num_partitions = self.batches.len();
let mut new_batches = vec![vec![]; num_partitions];
let mut i = 0;
let mut row_count = 0;
while let Some(batch) = data.next().await.transpose()? {
row_count += batch.num_rows();
new_batches[i].push(batch);
i = (i + 1) % num_partitions;
}
for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) {
target.write().await.append(&mut batches);
}
Ok(row_count as u64)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::provider_as_source;
    use crate::physical_plan::collect;
    use crate::prelude::SessionContext;
    use arrow::array::{AsArray, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
    use arrow::error::ArrowError;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::StreamExt;
    use std::collections::HashMap;

    /// Scanning with a projection yields only the projected columns, in
    /// projection order.
    #[tokio::test]
    async fn test_with_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![None, None, Some(9)])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        // Project columns "c" (index 2) then "b" (index 1).
        let exec = provider
            .scan(&session_ctx.state(), Some(&vec![2, 1]), &[], None)
            .await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch2 = it.next().await.unwrap()?;
        assert_eq!(2, batch2.schema().fields().len());
        assert_eq!("c", batch2.schema().field(0).name());
        assert_eq!("b", batch2.schema().field(1).name());
        assert_eq!(2, batch2.num_columns());
        Ok(())
    }

    /// Scanning without a projection yields every column.
    #[tokio::test]
    async fn test_without_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());
        Ok(())
    }

    /// An out-of-bounds projection index surfaces as an Arrow schema error
    /// at scan time.
    #[tokio::test]
    async fn test_invalid_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        // Index 4 is past the last column (max index 2).
        let projection: Vec<usize> = vec![0, 4];
        match provider
            .scan(&session_ctx.state(), Some(&projection), &[], None)
            .await
        {
            Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
                assert_eq!(
                    "\"project index 4 out of bounds, max field 3\"",
                    format!("{e:?}")
                )
            }
            res => panic!("Scan should failed on invalid projection, got {res:?}"),
        };
        Ok(())
    }

    /// `try_new` rejects batches whose column types differ from the table
    /// schema.
    #[test]
    fn test_schema_validation_incompatible_column() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        // Same field names, but "b" has an incompatible type (Float64).
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => {
                assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}"))
            }
            _ => panic!("MemTable::new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// `try_new` rejects batches that have a different column count than the
    /// table schema.
    #[test]
    fn test_schema_validation_different_column_count() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![7, 5, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => {
                assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}"))
            }
            _ => panic!("MemTable::new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// Batches whose schemas are each contained in a merged superset schema
    /// are accepted and scannable.
    #[tokio::test]
    async fn test_merged_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema1 = Schema::new_with_metadata(
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ],
            // Schema-level metadata must survive the merge.
            metadata,
        );
        // Same fields, but "a" is nullable here.
        let schema2 = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;
        let provider =
            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());
        Ok(())
    }

    /// Helper: create a table with `initial_data`, insert `inserted_data`
    /// via an INSERT INTO plan, verify the reported row count, and return
    /// the table's resulting per-partition batches.
    async fn experiment(
        schema: SchemaRef,
        initial_data: Vec<Vec<RecordBatch>>,
        inserted_data: Vec<Vec<RecordBatch>>,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let expected_count: u64 = inserted_data
            .iter()
            .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
            .sum();
        let session_ctx = SessionContext::new();
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
        session_ctx.register_table("t", initial_table.clone())?;
        let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
        session_ctx.register_table("source", source_table.clone())?;
        // Build: INSERT INTO t SELECT * FROM source.
        let source = provider_as_source(source_table);
        let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
        let insert_into_table =
            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?;
        let plan = session_ctx
            .state()
            .create_physical_plan(&insert_into_table)
            .await?;
        let res = collect(plan, session_ctx.task_ctx()).await?;
        // The insert plan reports the number of rows written.
        assert_eq!(extract_count(res), expected_count);
        // Snapshot the table's partitions after the insert.
        let mut partitions = vec![];
        for partition in initial_table.batches.iter() {
            let part = partition.read().await.clone();
            partitions.push(part);
        }
        Ok(partitions)
    }

    /// Extract the single u64 row-count value from the insert plan's
    /// one-batch, one-column, one-row result.
    fn extract_count(res: Vec<RecordBatch>) -> u64 {
        assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
        let batch = &res[0];
        assert_eq!(
            batch.num_columns(),
            1,
            "expected 1 column, got {}",
            batch.num_columns()
        );
        let col = batch.column(0).as_primitive::<UInt64Type>();
        assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
        let val = col
            .iter()
            .next()
            .expect("had value")
            .expect("expected non null");
        val
    }

    /// Inserting one batch into a single-partition table appends it to
    /// that partition.
    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table =
            experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]])
                .await?;
        // 1 initial batch + 1 inserted batch.
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }

    /// Inserting from a two-partition source into a single-partition table
    /// lands all batches in the one partition.
    #[tokio::test]
    async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()]],
            vec![vec![batch.clone()], vec![batch]],
        )
        .await?;
        // 1 initial batch + 2 inserted batches.
        assert_eq!(resulting_data_in_table[0].len(), 3);
        Ok(())
    }

    /// Inserting four batches into a two-partition table distributes them
    /// round-robin (two per partition).
    #[tokio::test]
    async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()], vec![batch.clone()]],
            vec![
                vec![batch.clone(), batch.clone()],
                vec![batch.clone(), batch],
            ],
        )
        .await?;
        // Each partition: 1 initial batch + 2 inserted batches.
        assert_eq!(resulting_data_in_table[0].len(), 3);
        assert_eq!(resulting_data_in_table[1].len(), 3);
        Ok(())
    }

    /// Inserting from an empty source leaves the table unchanged.
    #[tokio::test]
    async fn test_insert_from_empty_table() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone(), batch.clone()]],
            vec![vec![]],
        )
        .await?;
        // Still only the 2 initial batches.
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }
}