//! Barrier for the merge operation. Buffers the merged record stream per
//! source file and reports, through a shared survivor set, which files
//! contain modified rows, so unmodified files do not need to be rewritten.
use std::{
collections::HashMap,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use arrow::array::{Array, ArrayRef, RecordBatch, builder::UInt64Builder};
use arrow::datatypes::SchemaRef;
use dashmap::DashSet;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{Distribution, PhysicalExpr};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream,
};
use futures::{Stream, StreamExt};
use crate::{
DeltaTableError,
delta_datafusion::get_path_column,
operations::merge::{TARGET_DELETE_COLUMN, TARGET_INSERT_COLUMN, TARGET_UPDATE_COLUMN},
};
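/// Thread-safe set of file paths that contain at least one modified row,
/// shared between the execution plan and the merge operation that spawned it.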
pub(crate) type BarrierSurvivorSet = Arc<DashSet<String>>;
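/// Physical node that determines which files contain merge-modified rows.
///
/// Input batches must be hash-partitioned on the file path column so that all
/// rows from one file reach the same partition. Rows are then re-grouped by
/// file: a file's batches are withheld until at least one of its rows is
/// marked inserted, updated, or deleted, and such files are recorded in the
/// shared survivor set so that only they are rewritten.
///
/// A minimal wiring sketch (hypothetical `input` plan and `task_ctx`; it
/// mirrors the `execute` helper in the tests below):
///
/// ```ignore
/// let expr = Arc::new(Column::new("__delta_rs_path", 2));
/// let barrier = MergeBarrierExec::new(input, Arc::new("__delta_rs_path".to_string()), expr);
/// let survivors = barrier.survivors();
/// let mut stream = barrier.execute(0, task_ctx)?;
/// // drain `stream`; afterwards `survivors` names the files to rewrite
/// ```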
#[derive(Debug)]
pub struct MergeBarrierExec {
input: Arc<dyn ExecutionPlan>,
file_column: Arc<String>,
survivors: BarrierSurvivorSet,
expr: Arc<dyn PhysicalExpr>,
}
impl MergeBarrierExec {
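    /// Wrap `input`, tracking survivors through the dictionary-encoded
    /// `file_column` and requiring the input to be hash-partitioned on `expr`.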
pub fn new(
input: Arc<dyn ExecutionPlan>,
file_column: Arc<String>,
expr: Arc<dyn PhysicalExpr>,
) -> Self {
MergeBarrierExec {
input,
file_column,
survivors: Arc::new(DashSet::new()),
expr,
}
}
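    /// Shared handle to the survivor set; it is fully populated only after
    /// the stream has been drained.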
pub fn survivors(&self) -> BarrierSurvivorSet {
self.survivors.clone()
}
}
impl ExecutionPlan for MergeBarrierExec {
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
        self.input.properties()
    }
    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::HashPartitioned(vec![self.expr.clone()])]
    }
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Plan(
"MergeBarrierExec wrong number of children".to_string(),
));
}
Ok(Arc::new(MergeBarrierExec::new(
children[0].clone(),
self.file_column.clone(),
self.expr.clone(),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<datafusion::execution::TaskContext>,
) -> datafusion::common::Result<datafusion::physical_plan::SendableRecordBatchStream> {
let input = self.input.execute(partition, context)?;
Ok(Box::pin(MergeBarrierStream::new(
input,
self.schema(),
self.survivors.clone(),
self.file_column.clone(),
)))
}
}
impl DisplayAs for MergeBarrierExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match t {
DisplayFormatType::Default
| DisplayFormatType::Verbose
| DisplayFormatType::TreeRender => {
write!(f, "MergeBarrier",)?;
Ok(())
}
}
}
}
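/// States of [`MergeBarrierStream`]: `Feed` pulls input and splits it per
/// file, `Drain` flushes opened partitions, `Finalize` flushes what remains
/// and records survivors, and `Abort`/`Done` terminate the stream.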
#[derive(Debug)]
enum State {
Feed,
Drain,
Finalize,
Abort,
Done,
}
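/// Whether a file's buffered batches may be released downstream.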
#[derive(Debug)]
enum PartitionBarrierState {
Closed,
Open,
}
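/// Buffer for a single file. Starts `Closed`; opens permanently once a fed
/// batch contains a modified row. `file_name` is `None` for the catch-all
/// partition that holds rows without a source file.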
#[derive(Debug)]
struct MergeBarrierPartition {
state: PartitionBarrierState,
buffer: Vec<RecordBatch>,
file_name: Option<String>,
}
impl MergeBarrierPartition {
pub fn new(file_name: Option<String>) -> Self {
MergeBarrierPartition {
state: PartitionBarrierState::Closed,
buffer: Vec::new(),
file_name,
}
}
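    /// Buffer `batch`, flipping the partition to `Open` once any operation
    /// column reports a modified row.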
pub fn feed(&mut self, batch: RecordBatch) -> DataFusionResult<()> {
match self.state {
PartitionBarrierState::Closed => {
let delete_count = get_count(&batch, TARGET_DELETE_COLUMN)?;
let update_count = get_count(&batch, TARGET_UPDATE_COLUMN)?;
let insert_count = get_count(&batch, TARGET_INSERT_COLUMN)?;
self.buffer.push(batch);
if insert_count > 0 || update_count > 0 || delete_count > 0 {
self.state = PartitionBarrierState::Open;
}
}
PartitionBarrierState::Open => {
self.buffer.push(batch);
}
}
Ok(())
}
pub fn drain(&mut self) -> Option<RecordBatch> {
match self.state {
PartitionBarrierState::Closed => None,
PartitionBarrierState::Open => self.buffer.pop(),
}
}
}
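/// Stream that performs the barrier work: splits each input batch by file,
/// buffers the pieces, and only forwards batches for files whose rows were
/// actually modified.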
struct MergeBarrierStream {
schema: SchemaRef,
state: State,
input: SendableRecordBatchStream,
file_column: Arc<String>,
survivors: BarrierSurvivorSet,
map: HashMap<String, usize>,
file_partitions: Vec<MergeBarrierPartition>,
}
impl MergeBarrierStream {
pub fn new(
input: SendableRecordBatchStream,
schema: SchemaRef,
survivors: BarrierSurvivorSet,
file_column: Arc<String>,
) -> Self {
let file_partitions = vec![MergeBarrierPartition::new(None)];
MergeBarrierStream {
schema,
state: State::Feed,
input,
file_column,
survivors,
file_partitions,
map: HashMap::new(),
}
}
}
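/// Number of rows in `batch` affected by the operation tracked in `column`.
/// The merge projection leaves these boolean columns null exactly for the
/// rows the operation touched, so the null count is the modified-row count.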
fn get_count(batch: &RecordBatch, column: &str) -> DataFusionResult<usize> {
batch
.column_by_name(column)
.map(|array| array.null_count())
.ok_or_else(|| {
DataFusionError::External(Box::new(DeltaTableError::Generic(
"Required operation column is missing".to_string(),
)))
})
}
impl Stream for MergeBarrierStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.state {
State::Feed => {
match self.input.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
let file_dictionary = get_path_column(&batch, &self.file_column)?;
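                        // Dictionary keys are not stable across batches, so
                        // remap each key to a stable partition index keyed by
                        // file name.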
let mut key_map = Vec::with_capacity(file_dictionary.len());
for file_name in file_dictionary.values().into_iter() {
let key = match file_name {
Some(name) => {
if !self.map.contains_key(name) {
let key = self.file_partitions.len();
let part_stream =
MergeBarrierPartition::new(Some(name.to_string()));
self.file_partitions.push(part_stream);
self.map.insert(name.to_string(), key);
}
*self.map.get(name).unwrap()
}
None => 0,
};
key_map.push(key)
}
                        let mut indices: Vec<UInt64Builder> = (0..self.file_partitions.len())
                            .map(|_| UInt64Builder::with_capacity(batch.num_rows()))
                            .collect();
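                        // Route each row index to its file's index builder;
                        // rows with a null path fall back to partition 0.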
for (idx, key) in file_dictionary.keys().iter().enumerate() {
match key {
Some(value) => {
indices[key_map[value as usize]].append_value(idx as u64)
}
None => indices[0].append_value(idx as u64),
}
}
let batches: Vec<Result<(usize, RecordBatch), DataFusionError>> =
indices
.into_iter()
.enumerate()
.filter_map(|(partition, mut indices)| {
let indices = indices.finish();
(!indices.is_empty()).then_some((partition, indices))
})
.map(move |(partition, indices)| {
let columns = batch
.columns()
.iter()
.map(|c| {
Ok(arrow::compute::take(
c.as_ref(),
&indices,
None,
)?)
})
.collect::<DataFusionResult<Vec<ArrayRef>>>()?;
                                let batch = RecordBatch::try_new(batch.schema(), columns)?;
Ok((partition, batch))
})
.collect();
for batch in batches {
match batch {
Ok((partition, batch)) => {
self.file_partitions[partition].feed(batch)?;
}
Err(err) => {
self.state = State::Abort;
return Poll::Ready(Some(Err(err)));
}
}
}
self.state = State::Drain;
continue;
}
Poll::Ready(Some(Err(err))) => {
self.state = State::Abort;
return Poll::Ready(Some(Err(err)));
}
Poll::Ready(None) => {
self.state = State::Finalize;
continue;
}
Poll::Pending => return Poll::Pending,
}
}
State::Drain => {
for part in &mut self.file_partitions {
if let Some(batch) = part.drain() {
return Poll::Ready(Some(Ok(batch)));
}
}
self.state = State::Feed;
continue;
}
State::Finalize => {
for part in &mut self.file_partitions {
if let Some(batch) = part.drain() {
return Poll::Ready(Some(Ok(batch)));
}
}
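                    // Every partition that opened saw at least one modified
                    // row; record its file so the merge knows to rewrite it.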
                    for part in &self.file_partitions {
                        match part.state {
                            PartitionBarrierState::Closed => {}
                            PartitionBarrierState::Open => {
                                if let Some(file_name) = &part.file_name {
                                    self.survivors.insert(file_name.to_owned());
                                }
                            }
                        }
                    }
self.state = State::Done;
continue;
}
State::Abort => return Poll::Ready(None),
State::Done => return Poll::Ready(None),
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(0, self.input.size_hint().1)
}
}
impl RecordBatchStream for MergeBarrierStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
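/// Logical plan node for the barrier; the planner lowers it to
/// [`MergeBarrierExec`].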
#[derive(Debug, Hash, Eq, PartialEq, PartialOrd)]
pub(crate) struct MergeBarrier {
pub input: LogicalPlan,
pub expr: Expr,
pub file_column: Arc<String>,
}
impl UserDefinedLogicalNodeCore for MergeBarrier {
fn name(&self) -> &str {
"MergeBarrier"
}
fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> {
vec![&self.input]
}
fn schema(&self) -> &datafusion::common::DFSchemaRef {
self.input.schema()
}
fn expressions(&self) -> Vec<datafusion::logical_expr::Expr> {
vec![self.expr.clone()]
}
fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "MergeBarrier")
}
fn with_exprs_and_inputs(
&self,
exprs: Vec<datafusion::logical_expr::Expr>,
inputs: Vec<datafusion::logical_expr::LogicalPlan>,
) -> DataFusionResult<Self> {
Ok(MergeBarrier {
input: inputs[0].clone(),
file_column: self.file_column.clone(),
expr: exprs[0].clone(),
})
}
}
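/// Depth-first search for the first node of type `T` in a physical plan,
/// used to recover the [`MergeBarrierExec`] after planning. A hypothetical
/// call site:
///
/// ```ignore
/// let barrier = find_node::<MergeBarrierExec>(&physical_plan).expect("barrier exists");
/// ```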
pub(crate) fn find_node<T: 'static>(
parent: &Arc<dyn ExecutionPlan>,
) -> Option<Arc<dyn ExecutionPlan>> {
if parent.as_any().downcast_ref::<T>().is_some() {
return Some(parent.to_owned());
}
    for child in parent.children() {
let res = find_node::<T>(child);
if res.is_some() {
return res;
}
}
None
}
#[cfg(test)]
mod tests {
use super::BarrierSurvivorSet;
use crate::operations::merge::MergeBarrierExec;
use crate::operations::merge::{
TARGET_DELETE_COLUMN, TARGET_INSERT_COLUMN, TARGET_UPDATE_COLUMN,
};
use arrow::datatypes::Schema as ArrowSchema;
use arrow_array::RecordBatch;
use arrow_array::StringArray;
use arrow_array::{DictionaryArray, UInt16Array};
use arrow_schema::DataType as ArrowDataType;
use arrow_schema::Field;
use datafusion::assert_batches_sorted_eq;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::execution::TaskContext;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use futures::StreamExt;
use std::sync::Arc;
#[tokio::test]
async fn test_barrier() {
let schema = get_schema();
let keys = UInt16Array::from(vec![Some(0), Some(1), Some(2), None]);
let values = StringArray::from(vec![Some("file0"), Some("file1"), Some("file2")]);
let dict = DictionaryArray::new(keys, Arc::new(values));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::StringArray::from(vec!["0", "1", "2", "3"])),
Arc::new(dict),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
Some(false),
None,
])),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
None,
Some(false),
Some(false),
])),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
None,
Some(false),
])),
],
)
.unwrap();
let (actual, survivors) = execute(vec![batch]).await;
let expected = vec![
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| 1 | file1 | false | | false |",
"| 2 | file2 | false | false | |",
"| 3 | | | false | false |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
];
assert_batches_sorted_eq!(&expected, &actual);
assert!(!survivors.contains(&"file0".to_string()));
assert!(survivors.contains(&"file1".to_string()));
assert!(survivors.contains(&"file2".to_string()));
assert_eq!(survivors.len(), 2);
}
#[tokio::test]
async fn test_barrier_changing_indices() {
let schema = get_schema();
let mut batches = vec![];
let keys = UInt16Array::from(vec![Some(0), Some(1)]);
let values = StringArray::from(vec![Some("file0"), Some("file1")]);
let dict = DictionaryArray::new(keys, Arc::new(values));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::StringArray::from(vec!["0", "1"])),
Arc::new(dict),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
])),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
])),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
])),
],
)
.unwrap();
batches.push(batch);
let keys = UInt16Array::from(vec![Some(0), Some(1)]);
let values = StringArray::from(vec![Some("file1"), Some("file0")]);
let dict = DictionaryArray::new(keys, Arc::new(values));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::StringArray::from(vec!["2", "3"])),
Arc::new(dict),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
Some(false),
])),
Arc::new(arrow::array::BooleanArray::from(vec![None, Some(false)])),
Arc::new(arrow::array::BooleanArray::from(vec![Some(false), None])),
],
)
.unwrap();
batches.push(batch);
let (actual, _survivors) = execute(batches).await;
let expected = vec![
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| 0 | file0 | false | false | false |",
"| 1 | file1 | false | false | false |",
"| 2 | file1 | false | | false |",
"| 3 | file0 | false | false | |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
];
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_barrier_null_paths() {
let schema = get_schema();
let keys = UInt16Array::from(vec![Some(0), None, Some(1)]);
let values = StringArray::from(vec![Some("file1"), None]);
let dict = DictionaryArray::new(keys, Arc::new(values));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::StringArray::from(vec!["1", "2", "3"])),
Arc::new(dict),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(false),
None,
None,
])),
Arc::new(arrow::array::BooleanArray::from(vec![false, false, false])),
Arc::new(arrow::array::BooleanArray::from(vec![false, false, false])),
],
)
.unwrap();
let (actual, _) = execute(vec![batch]).await;
let expected = vec![
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
"| 2 | | | false | false |",
"| 3 | | | false | false |",
"+----+-----------------+--------------------------+--------------------------+--------------------------+",
];
assert_batches_sorted_eq!(&expected, &actual);
}
async fn execute(input: Vec<RecordBatch>) -> (Vec<RecordBatch>, BarrierSurvivorSet) {
let schema = get_schema();
let repartition = Arc::new(Column::new("__delta_rs_path", 2));
let exec = MemorySourceConfig::try_new_exec(&[input], schema.clone(), None).unwrap();
let task_ctx = Arc::new(TaskContext::default());
let merge =
MergeBarrierExec::new(exec, Arc::new("__delta_rs_path".to_string()), repartition);
let survivors = merge.survivors();
let coalescence = CoalesceBatchesExec::new(Arc::new(merge), 100);
let mut stream = coalescence.execute(0, task_ctx).unwrap();
(vec![stream.next().await.unwrap().unwrap()], survivors)
}
fn get_schema() -> Arc<ArrowSchema> {
Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new(
"__delta_rs_path",
ArrowDataType::Dictionary(
Box::new(ArrowDataType::UInt16),
Box::new(ArrowDataType::Utf8),
),
true,
),
Field::new(TARGET_INSERT_COLUMN, ArrowDataType::Boolean, true),
Field::new(TARGET_UPDATE_COLUMN, ArrowDataType::Boolean, true),
Field::new(TARGET_DELETE_COLUMN, ArrowDataType::Boolean, true),
]))
}
}