use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::sync::Arc;
use crate::datasource::{TableProvider, TableType};
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::insert::{DataSink, DataSinkExec};
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::{
common, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties,
Partitioning, SendableRecordBatchStream,
};
use crate::physical_planner::create_physical_sort_exprs;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
use datafusion_execution::TaskContext;
use datafusion_physical_plan::metrics::MetricsSet;
use async_trait::async_trait;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
use tokio::sync::RwLock;
use tokio::task::JoinSet;
/// Type alias for the data of a single partition: a shared, async-locked
/// vector of [`RecordBatch`]es. Shared (`Arc`) between a [`MemTable`] and any
/// sinks writing into it.
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
/// An in-memory table provider backed by vectors of [`RecordBatch`]es,
/// one vector per partition.
#[derive(Debug)]
pub struct MemTable {
    // Schema that every stored batch must be compatible with (see `try_new`).
    schema: SchemaRef,
    // One shared, async-locked batch vector per partition.
    pub(crate) batches: Vec<PartitionData>,
    // Table constraints; empty unless set via `with_constraints`.
    constraints: Constraints,
    // Per-column default expressions, looked up by `get_column_default`.
    column_defaults: HashMap<String, Expr>,
    /// Optional pre-known sort order(s) of the data, passed to the optimizer
    /// during `scan`. Cleared by `insert_into` since appended data may not
    /// preserve the order.
    pub sort_order: Arc<Mutex<Vec<Vec<Expr>>>>,
}
impl MemTable {
    /// Create a new in-memory table from the schema and partitioned batches.
    ///
    /// Returns a planning error if any batch's schema is not contained in
    /// the target `schema` (field-by-field compatibility via `SchemaExt`).
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        for batches in partitions.iter().flatten() {
            let batches_schema = batches.schema();
            if !schema.contains(&batches_schema) {
                debug!(
                    "mem table schema does not contain batches schema. \
                Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
                );
                return plan_err!("Mismatch between schema and batches");
            }
        }
        Ok(Self {
            schema,
            batches: partitions
                .into_iter()
                .map(|e| Arc::new(RwLock::new(e)))
                .collect::<Vec<_>>(),
            constraints: Constraints::empty(),
            column_defaults: HashMap::new(),
            sort_order: Arc::new(Mutex::new(vec![])),
        })
    }

    /// Set the table's constraints (builder style).
    pub fn with_constraints(mut self, constraints: Constraints) -> Self {
        self.constraints = constraints;
        self
    }

    /// Set per-column default expressions (builder style).
    pub fn with_column_defaults(
        mut self,
        column_defaults: HashMap<String, Expr>,
    ) -> Self {
        self.column_defaults = column_defaults;
        self
    }

    /// Declare that the stored data is already sorted by `sort_order`
    /// (builder style). The order is NOT validated against the data.
    pub fn with_sort_order(self, mut sort_order: Vec<Vec<Expr>>) -> Self {
        // Swap through the Mutex: `self` is taken by value but the sort order
        // lives behind shared interior mutability.
        std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
        self
    }

    /// Materialize the contents of another [`TableProvider`] into a new
    /// `MemTable`, optionally round-robin repartitioning the result into
    /// `output_partitions` partitions.
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &SessionState,
    ) -> Result<Self> {
        let schema = t.schema();
        // Unfiltered, unprojected, unlimited scan of the source table.
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        // Collect each source partition concurrently, one task per partition.
        let mut join_set = JoinSet::new();

        for part_idx in 0..partition_count {
            let task = state.task_ctx();
            let exec = exec.clone();
            join_set.spawn(async move {
                let stream = exec.execute(part_idx, task)?;
                common::collect(stream).await
            });
        }

        let mut data: Vec<Vec<RecordBatch>> =
            Vec::with_capacity(exec.output_partitioning().partition_count());

        // NOTE(review): results are pushed in task-completion order, so the
        // loaded partition order may differ from the source's partition
        // order — confirm callers do not rely on partition ordering here.
        while let Some(result) = join_set.join_next().await {
            match result {
                Ok(res) => data.push(res?),
                Err(e) => {
                    // Re-raise panics from worker tasks; the JoinSet is never
                    // aborted, so a non-panic join error should be impossible.
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        let exec = MemoryExec::try_new(&data, schema.clone(), None)?;

        if let Some(num_partitions) = output_partitions {
            // Redistribute the collected batches round-robin across the
            // requested number of output partitions.
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;

            // Drain every repartitioned stream back into plain batch vectors.
            let mut output_partitions = vec![];
            for i in 0..exec.properties().output_partitioning().partition_count() {
                let task_ctx = state.task_ctx();
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }

            return MemTable::try_new(schema.clone(), output_partitions);
        }
        MemTable::try_new(schema.clone(), data)
    }
}
#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn constraints(&self) -> Option<&Constraints> {
        Some(&self.constraints)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Create an execution plan scanning the in-memory partitions, applying
    /// the optional column `projection` and attaching any declared sort order.
    /// `_filters` and `_limit` are accepted but not pushed down.
    async fn scan(
        &self,
        state: &SessionState,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Snapshot every partition under its read lock so concurrent inserts
        // cannot observe a half-read scan.
        let mut partitions = vec![];
        for arc_inner_vec in self.batches.iter() {
            let inner_vec = arc_inner_vec.read().await;
            partitions.push(inner_vec.clone())
        }

        let mut exec =
            MemoryExec::try_new(&partitions, self.schema(), projection.cloned())?;

        let show_sizes = state.config_options().explain.show_sizes;
        exec = exec.with_show_sizes(show_sizes);

        // Expose any declared sort order to the optimizer.
        let sort_order = self.sort_order.lock();
        if !sort_order.is_empty() {
            let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;

            let file_sort_order = sort_order
                .iter()
                .map(|sort_exprs| {
                    create_physical_sort_exprs(
                        sort_exprs,
                        &df_schema,
                        state.execution_props(),
                    )
                })
                .collect::<Result<Vec<_>>>()?;
            exec = exec.with_sort_information(file_sort_order);
        }

        Ok(Arc::new(exec))
    }

    /// Create an execution plan that appends `input`'s output to this table.
    ///
    /// # Errors
    /// Returns a planning error if the input schema is not logically
    /// equivalent to the table schema, and a not-implemented error if
    /// `overwrite` is requested.
    async fn insert_into(
        &self,
        _state: &SessionState,
        input: Arc<dyn ExecutionPlan>,
        overwrite: bool,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Validate before mutating any table state so that a rejected insert
        // leaves the declared sort order intact (previously the sort order
        // was cleared even when validation failed below).
        if !self
            .schema()
            .logically_equivalent_names_and_types(&input.schema())
        {
            return plan_err!(
                "Inserting query must have the same schema with the table."
            );
        }
        if overwrite {
            return not_impl_err!("Overwrite not implemented for MemoryTable yet");
        }

        // Appended rows may violate any previously declared ordering, so the
        // sort order is no longer trustworthy once an insert is planned.
        *self.sort_order.lock() = vec![];

        let sink = Arc::new(MemSink::new(self.batches.clone()));
        Ok(Arc::new(DataSinkExec::new(
            input,
            sink,
            self.schema.clone(),
            None,
        )))
    }

    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.column_defaults.get(column)
    }
}
/// A sink that appends incoming record batches to the shared partitions of a
/// [`MemTable`].
struct MemSink {
    /// Target partitions for writing data; shared with the owning `MemTable`.
    batches: Vec<PartitionData>,
}
impl Debug for MemSink {
    /// Debug output shows only the number of target partitions, not the
    /// (potentially large) batch contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let num_partitions = self.batches.len();
        f.debug_struct("MemSink")
            .field("num_partitions", &num_partitions)
            .finish()
    }
}
impl DisplayAs for MemSink {
    /// Plan-display rendering; both the default and verbose formats print the
    /// number of target partitions.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let partition_count = self.batches.len();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "MemoryTable (partitions={partition_count})")
            }
        }
    }
}
impl MemSink {
fn new(batches: Vec<PartitionData>) -> Self {
Self { batches }
}
}
#[async_trait]
impl DataSink for MemSink {
fn as_any(&self) -> &dyn Any {
self
}
fn metrics(&self) -> Option<MetricsSet> {
None
}
async fn write_all(
&self,
mut data: SendableRecordBatchStream,
_context: &Arc<TaskContext>,
) -> Result<u64> {
let num_partitions = self.batches.len();
let mut new_batches = vec![vec![]; num_partitions];
let mut i = 0;
let mut row_count = 0;
while let Some(batch) = data.next().await.transpose()? {
row_count += batch.num_rows();
new_batches[i].push(batch);
i = (i + 1) % num_partitions;
}
for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) {
target.write().await.append(&mut batches);
}
Ok(row_count as u64)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::provider_as_source;
    use crate::physical_plan::collect;
    use crate::prelude::SessionContext;
    use arrow::array::{AsArray, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
    use arrow::error::ArrowError;
    use datafusion_common::DataFusionError;
    use datafusion_expr::LogicalPlanBuilder;

    /// Scanning with a projection returns only the requested columns,
    /// in the requested order.
    #[tokio::test]
    async fn test_with_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![None, None, Some(9)])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        // Project columns "c" then "b" (indices 2 and 1).
        let exec = provider
            .scan(&session_ctx.state(), Some(&vec![2, 1]), &[], None)
            .await?;

        let mut it = exec.execute(0, task_ctx)?;
        let batch2 = it.next().await.unwrap()?;
        assert_eq!(2, batch2.schema().fields().len());
        assert_eq!("c", batch2.schema().field(0).name());
        assert_eq!("b", batch2.schema().field(1).name());
        assert_eq!(2, batch2.num_columns());

        Ok(())
    }

    /// Scanning without a projection returns all columns.
    #[tokio::test]
    async fn test_without_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());

        Ok(())
    }

    /// An out-of-bounds projection index is rejected with an Arrow schema
    /// error rather than panicking.
    #[tokio::test]
    async fn test_invalid_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        // Index 4 is past the last column (max valid index is 2).
        let projection: Vec<usize> = vec![0, 4];

        match provider
            .scan(&session_ctx.state(), Some(&projection), &[], None)
            .await
        {
            Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => {
                assert_eq!(
                    "\"project index 4 out of bounds, max field 3\"",
                    format!("{e:?}")
                )
            }
            res => panic!("Scan should failed on invalid projection, got {res:?}"),
        };

        Ok(())
    }

    /// `try_new` rejects batches whose column types differ from the table
    /// schema.
    #[test]
    fn test_schema_validation_incompatible_column() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        // Same names, but "b" is Float64 instead of Int32.
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
        assert_eq!(
            "Error during planning: Mismatch between schema and batches",
            e.strip_backtrace()
        );

        Ok(())
    }

    /// `try_new` rejects batches with a different number of columns than the
    /// table schema.
    #[test]
    fn test_schema_validation_different_column_count() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![7, 5, 9])),
            ],
        )?;

        let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
        assert_eq!(
            "Error during planning: Mismatch between schema and batches",
            e.strip_backtrace()
        );

        Ok(())
    }

    /// Batches with different-but-compatible schemas (nullability, metadata)
    /// can coexist under a merged table schema.
    #[tokio::test]
    async fn test_merged_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());

        let schema1 = Schema::new_with_metadata(
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ],
            // test for comparing metadata
            metadata,
        );

        let schema2 = Schema::new(vec![
            // test for comparing nullability
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);

        let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider =
            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;

        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());

        Ok(())
    }

    /// Shared insert-test driver: builds a table from `initial_data`,
    /// inserts `inserted_data` via an INSERT INTO plan, asserts the reported
    /// row count, and returns the table's resulting partition contents.
    async fn experiment(
        schema: SchemaRef,
        initial_data: Vec<Vec<RecordBatch>>,
        inserted_data: Vec<Vec<RecordBatch>>,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let expected_count: u64 = inserted_data
            .iter()
            .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
            .sum();

        // Create a new session context
        let session_ctx = SessionContext::new();
        // Create and register the initial table with the provided schema and data
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
        session_ctx.register_table("t", initial_table.clone())?;
        // Create and register the source table with the provided schema and inserted data
        let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
        session_ctx.register_table("source", source_table.clone())?;
        // Convert the source table into a provider so that it can be used in a query
        let source = provider_as_source(source_table);
        // Create a table scan logical plan to read from the source table
        let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
        // Create an insert plan to insert the source data into the initial table
        let insert_into_table =
            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?;
        // Create a physical plan from the insert plan
        let plan = session_ctx
            .state()
            .create_physical_plan(&insert_into_table)
            .await?;

        // Execute the physical plan and collect the results
        let res = collect(plan, session_ctx.task_ctx()).await?;
        assert_eq!(extract_count(res), expected_count);

        // Read the data back out of the table's partitions to verify below.
        let mut partitions = vec![];
        for partition in initial_table.batches.iter() {
            let part = partition.read().await.clone();
            partitions.push(part);
        }
        Ok(partitions)
    }

    /// Extract the single u64 "count" value from an INSERT result, asserting
    /// the result has exactly one batch, one column, and one row.
    fn extract_count(res: Vec<RecordBatch>) -> u64 {
        assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
        let batch = &res[0];
        assert_eq!(
            batch.num_columns(),
            1,
            "expected 1 column, got {}",
            batch.num_columns()
        );
        let col = batch.column(0).as_primitive::<UInt64Type>();
        assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
        let val = col
            .iter()
            .next()
            .expect("had value")
            .expect("expected non null");
        val
    }

    /// Inserting one batch into a single-partition table appends to that
    /// partition.
    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and obtain the resulting data in the table
        let resulting_data_in_table =
            experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]])
                .await?;
        // Ensure that the table now contains two batches of data in the same partition
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }

    /// Inserting from a multi-partition source into a single-partition table
    /// funnels all batches into the one partition.
    #[tokio::test]
    async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()]],
            vec![vec![batch.clone()], vec![batch]],
        )
        .await?;
        // Ensure that the table now contains three batches of data in the same partition
        assert_eq!(resulting_data_in_table[0].len(), 3);
        Ok(())
    }

    /// Inserting from a multi-partition source into a multi-partition table
    /// distributes the new batches round-robin across the partitions.
    #[tokio::test]
    async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()], vec![batch.clone()]],
            vec![
                vec![batch.clone(), batch.clone()],
                vec![batch.clone(), batch],
            ],
        )
        .await?;
        // Ensure that each partition in the table now contains three batches of data
        assert_eq!(resulting_data_in_table[0].len(), 3);
        assert_eq!(resulting_data_in_table[1].len(), 3);
        Ok(())
    }

    /// Inserting from an empty source leaves the target table unchanged and
    /// reports a row count of zero.
    #[tokio::test]
    async fn test_insert_from_empty_table() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone(), batch.clone()]],
            vec![vec![]],
        )
        .await?;
        // Ensure that the table still contains its original two batches
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }
}