//! Merge data from a source dataset into a target Delta table based on a join
//! predicate, applying matched, not-matched, and not-matched-by-source clauses in a
//! single atomic commit.

use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::{Deref, Not};
use std::sync::Arc;
use std::time::Instant;
use arrow_schema::{DataType, Field, SchemaBuilder};
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::common::{
Column, DFSchema, DFSchemaRef, ExprSchema, ScalarValue, TableReference, plan_err,
};
use datafusion::datasource::provider_as_source;
use datafusion::error::Result as DataFusionResult;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::functions_window::expr_fn::row_number;
use datafusion::logical_expr::build_join_schema;
use datafusion::logical_expr::simplify::SimplifyContext;
use datafusion::logical_expr::utils::split_conjunction_owned;
use datafusion::logical_expr::{
Expr, ExprFunctionExt, JoinType, col, conditional_expressions::CaseBuilder, lit, when,
};
use datafusion::logical_expr::{
Extension, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE, UserDefinedLogicalNode,
};
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_plan::metrics::{MetricBuilder, MetricsSet};
use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
use datafusion::{
execution::context::SessionState,
physical_plan::ExecutionPlan,
prelude::{DataFrame, cast},
};
use delta_kernel::engine::arrow_conversion::{TryIntoArrow as _, TryIntoKernel as _};
use delta_kernel::schema::{ColumnMetadataKey, StructType};
use filter::try_construct_early_filter;
use futures::future::BoxFuture;
use parquet::file::properties::WriterProperties;
use serde::Serialize;
use tracing::*;
use uuid::Uuid;
use self::barrier::{MergeBarrier, MergeBarrierExec};
use self::validation::{MergeValidation, MergeValidationExec};
use super::{CustomExecuteHandler, Operation};
use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::delta_datafusion::logical::MetricObserver;
use crate::delta_datafusion::physical::{MetricObserverExec, find_metric_node, get_metric};
use crate::delta_datafusion::planner::DeltaPlanner;
use crate::delta_datafusion::utils::coerce_predicate_literals;
use crate::delta_datafusion::{
DataFusionMixins, DeltaColumn, DeltaScanExec, DeltaScanNext, SessionFallbackPolicy,
SessionResolveContext, create_session, normalize_path_as_file_id, resolve_file_column_name,
resolve_session_state, update_datafusion_session,
};
use crate::delta_datafusion::{Expression, into_expr, maybe_into_expr};
use crate::kernel::schema::cast::{merge_arrow_field, merge_arrow_schema};
use crate::kernel::transaction::{CommitBuilder, CommitProperties, PROTOCOL};
use crate::kernel::{Action, EagerSnapshot, StructTypeExt, new_metadata, resolve_snapshot};
use crate::logstore::LogStoreRef;
use crate::operations::cdc::*;
use crate::operations::merge::barrier::find_node;
use crate::operations::write::WriterStatsConfig;
use crate::operations::write::execution::write_execution_plan_v2;
use crate::operations::write::generated_columns::{
add_generated_columns, add_missing_generated_columns, gc_is_enabled,
};
use crate::protocol::{DeltaOperation, MergePredicate};
use crate::table::config::TablePropertiesExt as _;
use crate::table::state::DeltaTableState;
use crate::{DeltaResult, DeltaTable, DeltaTableError};
mod barrier;
mod filter;
mod validation;
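// Internal helper columns injected while planning the merge. They track row provenance
// (source vs. target), the operation selected for each row, and per-file row ordinals /
// match ranks used to handle multiple source rows matching the same target row.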
const SOURCE_COLUMN: &str = "__delta_rs_source";
const TARGET_COLUMN: &str = "__delta_rs_target";
const OPERATION_COLUMN: &str = "__delta_rs_operation";
const DELETE_COLUMN: &str = "__delta_rs_delete";
const TARGET_ROW_ORDINAL_IN_FILE_COLUMN: &str = "__delta_rs_target_row_ordinal_in_file";
const TARGET_MATCH_ROW_RANK_COLUMN: &str = "__delta_rs_target_match_row_rank";
pub(crate) const TARGET_INSERT_COLUMN: &str = "__delta_rs_target_insert";
pub(crate) const TARGET_UPDATE_COLUMN: &str = "__delta_rs_target_update";
pub(crate) const TARGET_DELETE_COLUMN: &str = "__delta_rs_target_delete";
pub(crate) const TARGET_COPY_COLUMN: &str = "__delta_rs_target_copy";
const TARGET_MATCH_CARDINALITY_CLASS_COLUMN: &str = "__delta_rs_match_cardinality_class";
const SOURCE_COUNT_METRIC: &str = "num_source_rows";
const TARGET_COUNT_METRIC: &str = "num_target_rows";
const TARGET_COPY_METRIC: &str = "num_copied_rows";
const TARGET_INSERTED_METRIC: &str = "num_target_inserted_rows";
const TARGET_UPDATED_METRIC: &str = "num_target_updated_rows";
const TARGET_DELETED_METRIC: &str = "num_target_deleted_rows";
const TARGET_FILES_SCANNED_METRIC: &str = "count_files_scanned";
const TARGET_FILES_SCANNED_METRIC_LEGACY: &str = "files_scanned";
const TARGET_FILES_PRUNED_METRIC: &str = "count_files_pruned";
const TARGET_FILES_SKIPPED_METRIC: &str = "count_files_skipped";
const TARGET_FILES_PRUNED_METRIC_LEGACY: &str = "files_pruned";
const SOURCE_COUNT_ID: &str = "merge_source_count";
const TARGET_COUNT_ID: &str = "merge_target_count";
const OUTPUT_COUNT_ID: &str = "merge_output_count";
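/// Merge data from a source [DataFrame] into a target Delta table based on a join
/// predicate, applying the configured matched / not-matched / not-matched-by-source
/// clauses in a single atomic commit.
///
/// Illustrative sketch mirroring this module's tests; assumes a `table`, a `source`
/// [DataFrame], and `id`/`value` columns:
/// ```ignore
/// let (table, metrics) = table
///     .merge(source, col("target.id").eq(col("source.id")))
///     .with_source_alias("source")
///     .with_target_alias("target")
///     .when_matched_update(|update| update.update("value", col("source.value")))?
///     .await?;
/// ```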
pub struct MergeBuilder {
predicate: Expression,
match_operations: Vec<MergeOperationConfig>,
not_match_operations: Vec<MergeOperationConfig>,
not_match_source_operations: Vec<MergeOperationConfig>,
source_alias: Option<String>,
target_alias: Option<String>,
snapshot: Option<EagerSnapshot>,
source: DataFrame,
streaming: bool,
merge_schema: bool,
log_store: LogStoreRef,
state: Option<Arc<dyn Session>>,
session_fallback_policy: SessionFallbackPolicy,
writer_properties: Option<WriterProperties>,
commit_properties: CommitProperties,
safe_cast: bool,
custom_execute_handler: Option<Arc<dyn CustomExecuteHandler>>,
}
impl super::Operation for MergeBuilder {
fn log_store(&self) -> &LogStoreRef {
&self.log_store
}
fn get_custom_execute_handler(&self) -> Option<Arc<dyn CustomExecuteHandler>> {
self.custom_execute_handler.clone()
}
}
impl MergeBuilder {
pub fn new<E: Into<Expression>>(
log_store: LogStoreRef,
snapshot: Option<EagerSnapshot>,
predicate: E,
source: DataFrame,
) -> Self {
let predicate = predicate.into();
Self {
predicate,
source,
snapshot,
log_store,
source_alias: None,
target_alias: None,
state: None,
session_fallback_policy: SessionFallbackPolicy::default(),
commit_properties: CommitProperties::default(),
writer_properties: None,
merge_schema: false,
match_operations: Vec::new(),
not_match_operations: Vec::new(),
not_match_source_operations: Vec::new(),
safe_cast: false,
streaming: false,
custom_execute_handler: None,
}
}
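/// Update a target record when it matches a source record, optionally restricted by
/// an additional predicate set on the [UpdateBuilder].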
pub fn when_matched_update<F>(mut self, builder: F) -> DeltaResult<MergeBuilder>
where
F: FnOnce(UpdateBuilder) -> UpdateBuilder,
{
let builder = builder(UpdateBuilder::default());
let op =
MergeOperationConfig::new(builder.predicate, builder.updates, OperationType::Update)?;
self.match_operations.push(op);
Ok(self)
}
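/// Delete a target record when it matches a source record, optionally restricted by
/// an additional predicate set on the [DeleteBuilder].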
pub fn when_matched_delete<F>(mut self, builder: F) -> DeltaResult<MergeBuilder>
where
F: FnOnce(DeleteBuilder) -> DeleteBuilder,
{
let builder = builder(DeleteBuilder::default());
let op = MergeOperationConfig::new(
builder.predicate,
HashMap::default(),
OperationType::Delete,
)?;
self.match_operations.push(op);
Ok(self)
}
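/// Insert a source record when it has no match in the target table, optionally
/// restricted by an additional predicate set on the [InsertBuilder].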
pub fn when_not_matched_insert<F>(mut self, builder: F) -> DeltaResult<MergeBuilder>
where
F: FnOnce(InsertBuilder) -> InsertBuilder,
{
let builder = builder(InsertBuilder::default());
let op = MergeOperationConfig::new(builder.predicate, builder.set, OperationType::Insert)?;
self.not_match_operations.push(op);
Ok(self)
}
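/// Update a target record when it has no match in the source dataset, optionally
/// restricted by an additional predicate set on the [UpdateBuilder].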
pub fn when_not_matched_by_source_update<F>(mut self, builder: F) -> DeltaResult<MergeBuilder>
where
F: FnOnce(UpdateBuilder) -> UpdateBuilder,
{
let builder = builder(UpdateBuilder::default());
let op =
MergeOperationConfig::new(builder.predicate, builder.updates, OperationType::Update)?;
self.not_match_source_operations.push(op);
Ok(self)
}
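/// Delete a target record when it has no match in the source dataset, optionally
/// restricted by an additional predicate set on the [DeleteBuilder].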
pub fn when_not_matched_by_source_delete<F>(mut self, builder: F) -> DeltaResult<MergeBuilder>
where
F: FnOnce(DeleteBuilder) -> DeleteBuilder,
{
let builder = builder(DeleteBuilder::default());
let op = MergeOperationConfig::new(
builder.predicate,
HashMap::default(),
OperationType::Delete,
)?;
self.not_match_source_operations.push(op);
Ok(self)
}
pub fn with_source_alias<S: ToString>(mut self, alias: S) -> Self {
self.source_alias = Some(alias.to_string());
self
}
pub fn with_target_alias<S: ToString>(mut self, alias: S) -> Self {
self.target_alias = Some(alias.to_string());
self
}
pub fn with_merge_schema(mut self, merge_schema: bool) -> Self {
self.merge_schema = merge_schema;
self
}
pub fn with_session_state(mut self, state: Arc<dyn Session>) -> Self {
self.state = Some(state);
self
}
pub fn with_session_fallback_policy(mut self, policy: SessionFallbackPolicy) -> Self {
self.session_fallback_policy = policy;
self
}
pub fn with_commit_properties(mut self, commit_properties: CommitProperties) -> Self {
self.commit_properties = commit_properties;
self
}
pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self {
self.writer_properties = Some(writer_properties);
self
}
pub fn with_safe_cast(mut self, safe_cast: bool) -> Self {
self.safe_cast = safe_cast;
self
}
pub fn with_streaming(mut self, streaming: bool) -> Self {
self.streaming = streaming;
self
}
pub fn with_custom_execute_handler(mut self, handler: Arc<dyn CustomExecuteHandler>) -> Self {
self.custom_execute_handler = Some(handler);
self
}
}
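/// Builder for the column assignments (and optional predicate) of an update clause.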
#[derive(Default)]
pub struct UpdateBuilder {
predicate: Option<Expression>,
updates: HashMap<Column, Expression>,
}
impl UpdateBuilder {
pub fn predicate<E: Into<Expression>>(mut self, predicate: E) -> Self {
self.predicate = Some(predicate.into());
self
}
pub fn update<C: Into<DeltaColumn>, E: Into<Expression>>(
mut self,
column: C,
expression: E,
) -> Self {
self.updates.insert(column.into().into(), expression.into());
self
}
}
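/// Builder for the column values (and optional predicate) of an insert clause.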
#[derive(Default)]
pub struct InsertBuilder {
predicate: Option<Expression>,
set: HashMap<Column, Expression>,
}
impl InsertBuilder {
pub fn predicate<E: Into<Expression>>(mut self, predicate: E) -> Self {
self.predicate = Some(predicate.into());
self
}
pub fn set<C: Into<DeltaColumn>, E: Into<Expression>>(
mut self,
column: C,
expression: E,
) -> Self {
self.set.insert(column.into().into(), expression.into());
self
}
}
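/// Builder for the optional predicate of a delete clause.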
#[derive(Default)]
pub struct DeleteBuilder {
predicate: Option<Expression>,
}
impl DeleteBuilder {
pub fn predicate<E: Into<Expression>>(mut self, predicate: E) -> Self {
self.predicate = Some(predicate.into());
self
}
}
#[derive(Debug, Copy, Clone)]
enum OperationType {
Update,
Delete,
SourceDelete,
Insert,
Copy,
}
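/// Ranks how a matched clause acts on a target row; higher values take precedence
/// when ranking multiple source matches against the same target row.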
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
enum MatchParticipationClass {
Ignore = 0,
MatchedNoop = 1,
MatchedUnconditionalDelete = 2,
MatchedAction = 3,
}
struct MergeOperationConfig {
predicate: Option<Expression>,
operations: HashMap<Column, Expression>,
r#type: OperationType,
}
struct MergeOperation {
predicate: Option<Expr>,
operations: HashMap<Column, Expr>,
r#type: OperationType,
match_participation_class: MatchParticipationClass,
}
impl MergeOperation {
fn try_from(
config: MergeOperationConfig,
schema: &DFSchema,
state: &dyn Session,
target_alias: &Option<String>,
) -> DeltaResult<MergeOperation> {
let mut ops = HashMap::with_capacity(config.operations.len());
for (column, expression) in config.operations.into_iter() {
let column = match target_alias {
Some(alias) => {
let r = TableReference::bare(alias.to_owned());
match column {
Column {
relation: None,
name,
spans,
} => Column {
relation: Some(r),
name,
spans,
},
Column {
relation: Some(TableReference::Bare { table }),
name,
spans,
} => {
if table.as_ref() == alias {
Column {
relation: Some(r),
name,
spans,
}
} else {
return Err(DeltaTableError::Generic(format!(
"Table alias '{table}' in column reference '{table}.{name}' unknown. Hint: You must reference the Delta Table with alias '{alias}'."
)));
}
}
_ => {
return Err(DeltaTableError::Generic(
"Column must reference column in Delta table".into(),
));
}
}
}
None => column,
};
ops.insert(column, into_expr(expression, schema, state)?);
}
Ok(MergeOperation {
predicate: maybe_into_expr(config.predicate, schema, state)?,
operations: ops,
r#type: config.r#type,
match_participation_class: MatchParticipationClass::Ignore,
})
}
fn into_matched(mut self) -> Self {
self.match_participation_class = match self.r#type {
OperationType::Delete if self.predicate.is_none() => {
MatchParticipationClass::MatchedUnconditionalDelete
}
OperationType::Delete | OperationType::Update => MatchParticipationClass::MatchedAction,
OperationType::Copy => MatchParticipationClass::MatchedNoop,
OperationType::Insert | OperationType::SourceDelete => MatchParticipationClass::Ignore,
};
self
}
}
impl MergeOperationConfig {
pub fn new(
predicate: Option<Expression>,
operations: HashMap<Column, Expression>,
r#type: OperationType,
) -> DeltaResult<Self> {
Ok(MergeOperationConfig {
predicate,
operations,
r#type,
})
}
}
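/// Metrics reported after a merge operation completes.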
#[derive(Default, Serialize, Debug)]
pub struct MergeMetrics {
pub num_source_rows: usize,
pub num_target_rows_inserted: usize,
pub num_target_rows_updated: usize,
pub num_target_rows_deleted: usize,
pub num_target_rows_copied: usize,
pub num_output_rows: usize,
pub num_target_files_scanned: usize,
pub num_target_files_skipped_during_scan: usize,
pub num_target_files_added: usize,
pub num_target_files_removed: usize,
pub execution_time_ms: u64,
pub scan_time_ms: u64,
pub rewrite_time_ms: u64,
}
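/// Physical planner extension that lowers the merge-specific logical nodes
/// (metric observers, validation, and barrier) into their execution plans.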
#[derive(Clone, Debug)]
pub(crate) struct MergeMetricExtensionPlanner {}
impl MergeMetricExtensionPlanner {
pub fn new() -> Arc<Self> {
Arc::new(Self {})
}
}
#[async_trait]
impl ExtensionPlanner for MergeMetricExtensionPlanner {
async fn plan_extension(
&self,
planner: &dyn PhysicalPlanner,
node: &dyn UserDefinedLogicalNode,
_logical_inputs: &[&LogicalPlan],
physical_inputs: &[Arc<dyn ExecutionPlan>],
session_state: &SessionState,
) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
if let Some(metric_observer) = node.as_any().downcast_ref::<MetricObserver>() {
if metric_observer.id.eq(SOURCE_COUNT_ID) {
return Ok(Some(MetricObserverExec::try_new(
SOURCE_COUNT_ID.into(),
physical_inputs,
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(SOURCE_COUNT_METRIC)
.add(batch.num_rows());
},
)?));
}
if metric_observer.id.eq(TARGET_COUNT_ID) {
return Ok(Some(MetricObserverExec::try_new(
TARGET_COUNT_ID.into(),
physical_inputs,
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(TARGET_COUNT_METRIC)
.add(batch.num_rows());
},
)?));
}
if metric_observer.id.eq(OUTPUT_COUNT_ID) {
return Ok(Some(MetricObserverExec::try_new(
OUTPUT_COUNT_ID.into(),
physical_inputs,
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(TARGET_INSERTED_METRIC)
.add(
batch
.column_by_name(TARGET_INSERT_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_UPDATED_METRIC)
.add(
batch
.column_by_name(TARGET_UPDATE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_DELETED_METRIC)
.add(
batch
.column_by_name(TARGET_DELETE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_COPY_METRIC)
.add(
batch
.column_by_name(TARGET_COPY_COLUMN)
.unwrap()
.null_count(),
);
},
)?));
}
}
if let Some(validation) = node.as_any().downcast_ref::<MergeValidation>() {
if physical_inputs.len() != 1 {
return plan_err!("MergeValidationExec expects exactly one input");
}
let schema = validation.input.schema();
return Ok(Some(Arc::new(MergeValidationExec::new(
physical_inputs.first().unwrap().clone(),
planner.create_physical_expr(&validation.file_expr, schema, session_state)?,
Arc::clone(&validation.file_column),
Arc::clone(&validation.row_ordinal_column),
))));
}
if let Some(barrier) = node.as_any().downcast_ref::<MergeBarrier>() {
if physical_inputs.len() != 1 {
return plan_err!("MergeBarrierExec expects exactly one input");
}
let schema = barrier.input.schema();
return Ok(Some(Arc::new(MergeBarrierExec::new(
physical_inputs.first().unwrap().clone(),
barrier.file_column.clone(),
planner.create_physical_expr(&barrier.expr, schema, session_state)?,
))));
}
Ok(None)
}
}
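/// Plan and run the merge: join source and target, apply the configured clauses,
/// write the resulting data, and commit the add/remove actions together with metrics.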
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(operation = "merge", version = snapshot.version(), table_uri = %log_store.root_url()))]
async fn execute(
predicate: Expression,
mut source: DataFrame,
log_store: LogStoreRef,
snapshot: EagerSnapshot,
state: SessionState,
writer_properties: Option<WriterProperties>,
mut commit_properties: CommitProperties,
_safe_cast: bool,
streaming: bool,
source_alias: Option<String>,
target_alias: Option<String>,
merge_schema: bool,
match_operations: Vec<MergeOperationConfig>,
not_match_target_operations: Vec<MergeOperationConfig>,
not_match_source_operations: Vec<MergeOperationConfig>,
operation_id: Uuid,
handle: Option<&Arc<dyn CustomExecuteHandler>>,
) -> DeltaResult<(EagerSnapshot, MergeMetrics)> {
info!(
operation = "merge",
version = snapshot.version(),
"starting merge execution"
);
let mut metrics = MergeMetrics::default();
let exec_start = Instant::now();
let should_cdc = should_write_cdc(&snapshot)?;
if should_cdc {
debug!("Executing a merge and I should write CDC!");
}
info!(cdc_enabled = should_cdc, "merge execution details");
let current_metadata = snapshot.metadata();
let merge_planner = DeltaPlanner::new();
let state = SessionStateBuilder::new_from_existing(state)
.with_query_planner(merge_planner)
.build();
let source_name = match &source_alias {
Some(alias) => TableReference::bare(alias.to_string()),
None => TableReference::bare(UNNAMED_TABLE),
};
let target_name = match &target_alias {
Some(alias) => TableReference::bare(alias.to_string()),
None => TableReference::bare(UNNAMED_TABLE),
};
let mut generated_col_exp = None;
let mut missing_generated_col = None;
if gc_is_enabled(&snapshot) {
let generated_col_expressions = snapshot.schema().get_generated_columns()?;
let (source_with_gc, missing_generated_columns) =
add_missing_generated_columns(source, &generated_col_expressions)?;
source = source_with_gc;
generated_col_exp = Some(generated_col_expressions);
missing_generated_col = Some(missing_generated_columns);
}
let source = LogicalPlanBuilder::scan(
source_name.clone(),
provider_as_source(source.into_view()),
None,
)?
.build()?;
let source = LogicalPlan::Extension(Extension {
node: Arc::new(MetricObserver {
id: SOURCE_COUNT_ID.into(),
input: source,
enable_pushdown: false,
}),
});
let file_column = Arc::new(resolve_file_column_name(
snapshot.input_schema().as_ref(),
None,
)?);
let target_provider = provider_as_source(
DeltaScanNext::builder()
.with_eager_snapshot(snapshot.clone())
.with_log_store(log_store.clone())
.with_file_column(file_column.as_str())
.await?,
);
let target =
LogicalPlanBuilder::scan(target_name.clone(), target_provider.clone(), None)?.build()?;
let source_schema = source.schema();
let target_schema = target.schema();
let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?;
let predicate = predicate.resolve(&state, Arc::new(join_schema_df.clone()))?;
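// An early target filter is only attempted when there are no NOT MATCHED BY SOURCE
// clauses, since those clauses must observe every target row.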
let target_subset_filter: Option<Expr> = if !not_match_source_operations.is_empty() {
None
} else {
try_construct_early_filter(
predicate.clone(),
&snapshot,
&state,
&source,
&source_name,
&target_name,
streaming,
)
.await?
}
.map(|e| normalize_target_subset_filter(target.schema().clone(), e))
.transpose()?;
let commit_predicate = match target_subset_filter.clone() {
None => None,
Some(some_filter) => {
let predict_expr = match &target_alias {
None => some_filter,
Some(alias) => remove_table_alias(some_filter, alias),
};
Some(fmt_expr_to_sql(&predict_expr)?)
}
};
debug!("Using target subset filter: {commit_predicate:?}");
let file_skipping_predicates =
build_file_skipping_predicates(target_subset_filter, target_alias.as_deref());
let target_provider = {
let mut builder = DeltaScanNext::builder()
.with_eager_snapshot(snapshot.clone())
.with_log_store(log_store.clone())
.with_file_column(file_column.as_str());
if !file_skipping_predicates.is_empty() {
builder = builder.with_file_skipping_predicates(file_skipping_predicates);
}
provider_as_source(builder.await?)
};
let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?;
let source = DataFrame::new(state.clone(), source.clone());
let source = source.with_column(SOURCE_COLUMN, lit(true))?;
let enable_pushdown =
not_match_source_operations.is_empty() && not_match_target_operations.is_empty();
let target = LogicalPlan::Extension(Extension {
node: Arc::new(MetricObserver {
id: TARGET_COUNT_ID.into(),
input: target,
enable_pushdown,
}),
});
let target = DataFrame::new(state.clone(), target);
let target = target.with_column(
TARGET_ROW_ORDINAL_IN_FILE_COLUMN,
row_number()
.partition_by(vec![col(file_column.as_str())])
.build()?,
)?;
let target = target.with_column(TARGET_COLUMN, lit(true))?;
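// A full outer join keeps matched rows as well as rows present only in the source or
// only in the target, so every clause type can be evaluated downstream.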
let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?;
let join_schema_df = join.schema().to_owned();
let match_operations: Vec<MergeOperation> = match_operations
.into_iter()
.map(|op| {
MergeOperation::try_from(op, &join_schema_df, &state, &target_alias)
.map(MergeOperation::into_matched)
})
.collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
let not_match_target_operations: Vec<MergeOperation> = not_match_target_operations
.into_iter()
.map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
.collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
let not_match_source_operations: Vec<MergeOperation> = not_match_source_operations
.into_iter()
.map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
.collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
let mut new_schema = None;
let mut schema_action = None;
if merge_schema {
let logical_schema = snapshot.input_schema();
let logical_target_schema =
DFSchema::try_from_qualified_schema(target_name.clone(), logical_schema.as_ref())?;
let merge_schema =
merge_arrow_schema(logical_schema, source_schema.inner().clone(), false)?;
let mut schema_builder = SchemaBuilder::from(merge_schema.deref());
modify_schema(
&mut schema_builder,
&logical_target_schema,
source_schema,
&match_operations,
)?;
modify_schema(
&mut schema_builder,
&logical_target_schema,
source_schema,
¬_match_source_operations,
)?;
modify_schema(
&mut schema_builder,
&logical_target_schema,
source_schema,
¬_match_target_operations,
)?;
let schema = Arc::new(schema_builder.finish());
new_schema = Some(schema.clone());
let schema_struct: StructType = schema.try_into_kernel()?;
if &schema_struct != snapshot.schema().as_ref() {
let action = Action::Metadata(new_metadata(
&schema_struct,
current_metadata.partition_columns(),
snapshot.metadata().configuration(),
)?);
schema_action = Some(action);
}
}
let matched = col(SOURCE_COLUMN)
.is_true()
.and(col(TARGET_COLUMN).is_true());
let not_matched_target = col(SOURCE_COLUMN)
.is_true()
.and(col(TARGET_COLUMN).is_null());
let not_matched_source = col(SOURCE_COLUMN)
.is_null()
.and(col(TARGET_COLUMN).is_true());
let operations_size = match_operations.len()
+ not_match_source_operations.len()
+ not_match_target_operations.len()
+ 3;
let mut when_expr = Vec::with_capacity(operations_size);
let mut then_expr = Vec::with_capacity(operations_size);
let mut ops: Vec<(
HashMap<Column, Expr>,
OperationType,
MatchParticipationClass,
)> = Vec::with_capacity(operations_size);
fn update_case(
operations: Vec<MergeOperation>,
ops: &mut Vec<(
HashMap<Column, Expr>,
OperationType,
MatchParticipationClass,
)>,
when_expr: &mut Vec<Expr>,
then_expr: &mut Vec<Expr>,
base_expr: &Expr,
) -> DeltaResult<Vec<MergePredicate>> {
let mut predicates = Vec::with_capacity(operations.len());
for op in operations {
let predicate = match &op.predicate {
Some(predicate) => base_expr.clone().and(predicate.to_owned()),
None => base_expr.clone(),
};
when_expr.push(predicate);
then_expr.push(lit(ops.len() as i32));
ops.push((op.operations, op.r#type, op.match_participation_class));
let action_type = match op.r#type {
OperationType::Update => "update",
OperationType::Delete => "delete",
OperationType::Insert => "insert",
OperationType::SourceDelete => {
return Err(DeltaTableError::Generic("Invalid action type".to_string()));
}
OperationType::Copy => {
return Err(DeltaTableError::Generic("Invalid action type".to_string()));
}
};
let action_type = action_type.to_string();
let predicate = op
.predicate
.map(|expr| fmt_expr_to_sql(&expr))
.transpose()?;
predicates.push(MergePredicate {
action_type,
predicate,
});
}
Ok(predicates)
}
let match_operations = update_case(
match_operations,
&mut ops,
&mut when_expr,
&mut then_expr,
&matched,
)?;
let not_match_target_operations = update_case(
not_match_target_operations,
&mut ops,
&mut when_expr,
&mut then_expr,
¬_matched_target,
)?;
let not_match_source_operations = update_case(
not_match_source_operations,
&mut ops,
&mut when_expr,
&mut then_expr,
¬_matched_source,
)?;
when_expr.push(matched.clone());
then_expr.push(lit(ops.len() as i32));
ops.push((
HashMap::new(),
OperationType::Copy,
MatchParticipationClass::MatchedNoop,
));
when_expr.push(not_matched_target);
then_expr.push(lit(ops.len() as i32));
ops.push((
HashMap::new(),
OperationType::SourceDelete,
MatchParticipationClass::Ignore,
));
when_expr.push(not_matched_source);
then_expr.push(lit(ops.len() as i32));
ops.push((
HashMap::new(),
OperationType::Copy,
MatchParticipationClass::Ignore,
));
let case = CaseBuilder::new(None, when_expr, then_expr, None).end()?;
let projection = join.with_column(OPERATION_COLUMN, case)?;
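// For every column of the (possibly evolved) output schema, build a CASE expression
// over the operation id that yields the correct value for that row's action.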
let mut new_columns = vec![];
let mut write_projection = Vec::new();
let mut write_projection_with_cdf = Vec::new();
let schema = if let Some(schema) = new_schema {
Arc::new(schema.try_into_kernel()?)
} else {
snapshot.schema()
};
for delta_field in schema.fields() {
let mut when_expr = Vec::with_capacity(operations_size);
let mut then_expr = Vec::with_capacity(operations_size);
let qualifier = match &target_alias {
Some(alias) => Some(TableReference::Bare {
table: alias.to_owned().into(),
}),
None => TableReference::none(),
};
let mut null_target_column = None;
let source_qualifier = match &source_alias {
Some(alias) => Some(TableReference::Bare {
table: alias.to_owned().into(),
}),
None => TableReference::none(),
};
let name = delta_field.name();
let mut cast_type: DataType = delta_field.data_type().try_into_arrow()?;
let column = if let Some(field) = snapshot.schema().field(name) {
if field == delta_field {
Column::new(qualifier.clone(), name)
} else {
let col_ref = Column::new(source_qualifier.clone(), name);
cast_type = source_schema.data_type(&col_ref)?.to_owned();
col_ref
}
} else {
null_target_column = Some(cast(
lit(ScalarValue::Null).alias(name),
delta_field.data_type().try_into_arrow()?,
));
Column::new(source_qualifier.clone(), name)
};
for (idx, (operations, _, _)) in ops.iter().enumerate() {
let op: Expr = operations
.get(&column)
.map(|expr| expr.to_owned())
.unwrap_or_else(|| col(column.clone()));
when_expr.push(lit(idx as i32));
then_expr.push(op);
}
let case = CaseBuilder::new(
Some(Box::new(col(OPERATION_COLUMN))),
when_expr,
then_expr,
None,
)
.end()?;
let name = "__delta_rs_c_".to_owned() + delta_field.name();
write_projection.push(cast(
Expr::Column(Column::from_name(name.clone())).alias(delta_field.name()),
cast_type.clone(),
));
write_projection_with_cdf.push(
when(
col(CDC_COLUMN_NAME).not_eq(lit("update_preimage")),
cast(
Expr::Column(Column::from_name(name.clone())),
cast_type.clone(),
),
)
.otherwise(null_target_column.unwrap_or(cast(
Expr::Column(Column::new(qualifier, delta_field.name())),
cast_type,
)))?
.alias(delta_field.name()),
);
new_columns.push((name, case));
}
write_projection_with_cdf.push(col("_change_type"));
let mut insert_when = Vec::with_capacity(ops.len());
let mut insert_then = Vec::with_capacity(ops.len());
let mut update_when = Vec::with_capacity(ops.len());
let mut update_then = Vec::with_capacity(ops.len());
let mut target_delete_when = Vec::with_capacity(ops.len());
let mut target_delete_then = Vec::with_capacity(ops.len());
let mut delete_when = Vec::with_capacity(ops.len());
let mut delete_then = Vec::with_capacity(ops.len());
let mut copy_when = Vec::with_capacity(ops.len());
let mut copy_then = Vec::with_capacity(ops.len());
for (idx, (_operations, r#type, _)) in ops.iter().enumerate() {
let op = idx as i32;
delete_when.push(lit(op));
delete_then.push(lit(matches!(
r#type,
OperationType::Delete | OperationType::SourceDelete
)));
insert_when.push(lit(op));
insert_then.push(
when(
lit(matches!(r#type, OperationType::Insert)),
lit(ScalarValue::Boolean(None)),
)
.otherwise(lit(false))?,
);
update_when.push(lit(op));
update_then.push(
when(
lit(matches!(r#type, OperationType::Update)),
lit(ScalarValue::Boolean(None)),
)
.otherwise(lit(false))?,
);
target_delete_when.push(lit(op));
target_delete_then.push(
when(
lit(matches!(r#type, OperationType::Delete)),
lit(ScalarValue::Boolean(None)),
)
.otherwise(lit(false))?,
);
copy_when.push(lit(op));
copy_then.push(
when(
lit(matches!(r#type, OperationType::Copy)),
lit(ScalarValue::Boolean(None)),
)
.otherwise(lit(false))?,
);
}
fn build_case(when: Vec<Expr>, then: Vec<Expr>) -> DataFusionResult<Expr> {
CaseBuilder::new(
Some(Box::new(col(OPERATION_COLUMN))),
when,
then,
Some(Box::new(lit(false))),
)
.end()
}
new_columns.push((
DELETE_COLUMN.to_owned(),
build_case(delete_when, delete_then)?,
));
new_columns.push((
TARGET_INSERT_COLUMN.to_owned(),
build_case(insert_when, insert_then)?,
));
new_columns.push((
TARGET_UPDATE_COLUMN.to_owned(),
build_case(update_when, update_then)?,
));
new_columns.push((
TARGET_DELETE_COLUMN.to_owned(),
build_case(target_delete_when, target_delete_then)?,
));
new_columns.push((
TARGET_COPY_COLUMN.to_owned(),
build_case(copy_when, copy_then)?,
));
let new_columns = {
let plan = projection.into_unoptimized_plan();
let mut fields: Vec<Expr> = plan
.schema()
.columns()
.iter()
.map(|f| col(f.clone()))
.collect();
fields.extend(new_columns.into_iter().map(|(name, ex)| ex.alias(name)));
LogicalPlanBuilder::from(plan).project(fields)?.build()?
};
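// When matched clauses exist, rank each target row's matches (highest-precedence
// action first), route the plan through the validation node, and drop all but the
// top-ranked match per target row.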
let new_columns = if !match_operations.is_empty() {
let mut cardinality_when = Vec::with_capacity(ops.len());
let mut cardinality_then = Vec::with_capacity(ops.len());
for (idx, (_, _, cardinality_class)) in ops.iter().enumerate() {
cardinality_when.push(lit(idx as i32));
cardinality_then.push(lit(*cardinality_class as i32));
}
let cardinality_class = CaseBuilder::new(
Some(Box::new(col(OPERATION_COLUMN))),
cardinality_when,
cardinality_then,
Some(Box::new(lit(0))),
)
.end()?;
let match_row_rank = row_number()
.partition_by(vec![
col(file_column.as_str()),
col(TARGET_ROW_ORDINAL_IN_FILE_COLUMN),
])
.order_by(vec![
col(TARGET_MATCH_CARDINALITY_CLASS_COLUMN).sort(false, false),
])
.build()?;
let new_columns = DataFrame::new(state.clone(), new_columns)
.with_column(TARGET_MATCH_CARDINALITY_CLASS_COLUMN, cardinality_class)?
.with_column(TARGET_MATCH_ROW_RANK_COLUMN, match_row_rank)?
.into_unoptimized_plan();
let validated = LogicalPlan::Extension(Extension {
node: Arc::new(MergeValidation {
input: new_columns,
file_expr: col(file_column.as_str()),
file_column: Arc::clone(&file_column),
row_ordinal_column: Arc::new(TARGET_ROW_ORDINAL_IN_FILE_COLUMN.to_string()),
}),
});
DataFrame::new(state.clone(), validated)
.filter(
matched
.and(col(TARGET_MATCH_ROW_RANK_COLUMN).gt(lit(1_u64)))
.not(),
)?
.into_unoptimized_plan()
} else {
new_columns
};
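// The barrier tracks, per originating file, whether any row was modified or deleted;
// those files are later removed from the table while untouched files are kept.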
let distribute_expr = col(file_column.as_str());
let merge_barrier = LogicalPlan::Extension(Extension {
node: Arc::new(MergeBarrier {
input: new_columns.clone(),
expr: distribute_expr,
file_column: Arc::clone(&file_column),
}),
});
let operation_count = LogicalPlan::Extension(Extension {
node: Arc::new(MetricObserver {
id: OUTPUT_COUNT_ID.into(),
input: merge_barrier,
enable_pushdown: false,
}),
});
let operation_count = DataFrame::new(state.clone(), operation_count);
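// With change data feed enabled, tag each row with its change type and expand updates
// into pre/post images; otherwise drop deleted rows and use the plain write projection.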
let mut projected = if should_cdc {
operation_count
.clone()
.with_column(
CDC_COLUMN_NAME,
when(col(TARGET_DELETE_COLUMN).is_null(), lit("delete"))
.when(col(DELETE_COLUMN).is_null(), lit("source_delete"))
.when(col(TARGET_COPY_COLUMN).is_null(), lit("copy"))
.when(col(TARGET_INSERT_COLUMN).is_null(), lit("insert"))
.when(col(TARGET_UPDATE_COLUMN).is_null(), lit("update"))
.end()?,
)?
.drop_columns(&[file_column.as_str()])?
.with_column(
"__delta_rs_update_expanded",
when(
col(CDC_COLUMN_NAME).eq(lit("update")),
lit(ScalarValue::List(ScalarValue::new_list(
&[
ScalarValue::Utf8(Some("update_preimage".into())),
ScalarValue::Utf8(Some("update_postimage".into())),
],
&DataType::List(Field::new("element", DataType::Utf8, false).into()),
true,
))),
)
.end()?,
)?
.unnest_columns(&["__delta_rs_update_expanded"])?
.with_column(
CDC_COLUMN_NAME,
when(
col(CDC_COLUMN_NAME).eq(lit("update")),
col("__delta_rs_update_expanded"),
)
.otherwise(col(CDC_COLUMN_NAME))?,
)?
.drop_columns(&["__delta_rs_update_expanded"])?
.select(write_projection_with_cdf)?
} else {
operation_count
.filter(col(DELETE_COLUMN).is_false())?
.select(write_projection)?
};
if let Some(generated_col_expressions) = generated_col_exp
&& let Some(missing_generated_columns) = missing_generated_col
{
projected = add_generated_columns(
projected,
&generated_col_expressions,
&missing_generated_columns,
&state,
)?;
}
let merge_final = &projected.into_unoptimized_plan();
let write = state.create_physical_plan(merge_final).await?;
let err = || DeltaTableError::Generic("Unable to locate expected metric node".into());
let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?;
let target_count = find_metric_node(TARGET_COUNT_ID, &write).ok_or_else(err)?;
let op_count = find_metric_node(OUTPUT_COUNT_ID, &write).ok_or_else(err)?;
let barrier = find_node::<MergeBarrierExec>(&write).ok_or_else(err)?;
let scan_count = find_node::<DeltaScanExec>(&target_count)
.or_else(|| find_node::<DeltaScanExec>(&write))
.ok_or_else(err)?;
let table_partition_cols = current_metadata.partition_columns().to_vec();
let writer_stats_config = WriterStatsConfig::from_config(snapshot.table_configuration());
let (mut actions, write_plan_metrics) = write_execution_plan_v2(
Some(&snapshot),
&state,
write,
table_partition_cols.to_vec(),
log_store.object_store(Some(operation_id)),
Some(snapshot.table_properties().target_file_size()),
None,
writer_properties.clone(),
writer_stats_config.clone(),
None,
should_cdc,
None,
)
.await?;
if let Some(schema_metadata) = schema_action {
actions.push(schema_metadata);
}
metrics.rewrite_time_ms = write_plan_metrics.write_time_ms;
metrics.scan_time_ms = write_plan_metrics.scan_time_ms;
metrics.num_target_files_added = actions.len();
let survivors = barrier
.as_any()
.downcast_ref::<MergeBarrierExec>()
.unwrap()
.survivors();
let table_root = snapshot.table_configuration().table_root().clone();
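// Emit a remove action for every file the barrier reported as rewritten.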
for action in snapshot.log_data() {
let log_path = action.path_raw();
if should_remove_rewritten_file(&survivors, log_path, &table_root)? {
metrics.num_target_files_removed += 1;
actions.push(action.remove_action(true).into());
}
}
let source_count_metrics = source_count.metrics().unwrap();
let op_count_metrics = op_count.metrics().unwrap();
let scan_count_metrics = scan_count.metrics().unwrap();
metrics.num_source_rows = get_metric(&source_count_metrics, SOURCE_COUNT_METRIC);
metrics.num_target_rows_inserted = get_metric(&op_count_metrics, TARGET_INSERTED_METRIC);
metrics.num_target_rows_updated = get_metric(&op_count_metrics, TARGET_UPDATED_METRIC);
metrics.num_target_rows_deleted = get_metric(&op_count_metrics, TARGET_DELETED_METRIC);
metrics.num_target_rows_copied = get_metric(&op_count_metrics, TARGET_COPY_METRIC);
metrics.num_output_rows = metrics.num_target_rows_inserted
+ metrics.num_target_rows_updated
+ metrics.num_target_rows_copied;
let target_files_scanned_metric_names = [
TARGET_FILES_SCANNED_METRIC,
TARGET_FILES_SCANNED_METRIC_LEGACY,
];
metrics.num_target_files_scanned = get_metric_any_or(
&scan_count_metrics,
&target_files_scanned_metric_names,
|| {
warn!(
%operation_id,
metric_names = ?target_files_scanned_metric_names,
"Missing target scan metric; defaulting target files scanned to zero"
);
0
},
);
let target_files_skipped_metric_names = [
TARGET_FILES_PRUNED_METRIC,
TARGET_FILES_SKIPPED_METRIC,
TARGET_FILES_PRUNED_METRIC_LEGACY,
];
metrics.num_target_files_skipped_during_scan = get_metric_any_or(
&scan_count_metrics,
&target_files_skipped_metric_names,
|| {
let total_files = snapshot.log_data().num_files();
let (derived, impossible_state) =
derive_skipped_file_count(total_files, metrics.num_target_files_scanned);
if impossible_state {
warn!(
%operation_id,
total_files,
scanned_files = metrics.num_target_files_scanned,
metric_names = ?target_files_skipped_metric_names,
"Target scan metrics reported more scanned files than exist; clamping derived skipped-file count to zero"
);
}
warn!(
%operation_id,
metric_names = ?target_files_skipped_metric_names,
derived,
"Missing target skipped-file metric; deriving from total-files minus scanned-files"
);
derived
},
);
metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64;
let app_metadata = &mut commit_properties.app_metadata;
app_metadata.insert("readVersion".to_owned(), snapshot.version().into());
if let Ok(map) = serde_json::to_value(&metrics) {
app_metadata.insert("operationMetrics".to_owned(), map);
}
let operation = DeltaOperation::Merge {
predicate: commit_predicate,
merge_predicate: Some(fmt_expr_to_sql(&predicate)?),
matched_predicates: match_operations,
not_matched_predicates: not_match_target_operations,
not_matched_by_source_predicates: not_match_source_operations,
};
if actions.is_empty() {
return Ok((snapshot, metrics));
}
let commit = CommitBuilder::from(commit_properties)
.with_actions(actions)
.with_operation_id(operation_id)
.with_post_commit_hook_handler(handle.cloned())
.build(Some(&snapshot), log_store.clone(), operation)
.await?;
Ok((commit.snapshot().snapshot, metrics))
}
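/// Merge source fields referenced by update/insert clauses into the evolving target
/// schema, rejecting fields that carry generated-column expressions.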
fn modify_schema(
ending_schema: &mut SchemaBuilder,
target_schema: &DFSchema,
source_schema: &DFSchema,
operations: &[MergeOperation],
) -> DeltaResult<()> {
for columns in operations
.iter()
.filter(|ops| matches!(ops.r#type, OperationType::Update | OperationType::Insert))
.flat_map(|ops| ops.operations.keys())
{
let source_field = source_schema.field_with_unqualified_name(columns.name())?;
if source_field
.metadata()
.contains_key(ColumnMetadataKey::GenerationExpression.as_ref())
{
let error = arrow::error::ArrowError::SchemaError("Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.".to_string());
return Err(DeltaTableError::Arrow { source: error });
}
match target_schema.field_from_column(columns) {
Ok(target_field) => {
let new_field = merge_arrow_field(target_field, source_field, true)?;
if new_field != **target_field {
ending_schema.try_merge(&Arc::new(new_field))?;
}
}
Err(_) => {
ending_schema
.try_merge(&Arc::new(source_field.as_ref().clone().with_nullable(true)))?;
}
}
}
Ok(())
}
fn remove_table_alias(expr: Expr, table_alias: &str) -> Expr {
expr.transform(&|expr| match expr {
Expr::Column(c) => match c.relation {
Some(rel) if rel.table() == table_alias => Ok(Transformed::yes(Expr::Column(
Column::new_unqualified(c.name),
))),
_ => Ok(Transformed::no(Expr::Column(Column::new(
c.relation, c.name,
)))),
},
_ => Ok(Transformed::no(expr)),
})
.unwrap()
.data
}
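/// Coerce literal types in the derived target filter to the target schema and simplify
/// the resulting expression.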
fn normalize_target_subset_filter(target_schema: DFSchemaRef, expr: Expr) -> DeltaResult<Expr> {
let expr = coerce_predicate_literals(expr, target_schema.as_ref())?;
let simplify_context = SimplifyContext::default().with_schema(target_schema);
let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
Ok(simplifier.simplify(expr)?)
}
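/// Strip the target alias from the derived filter and split it into conjuncts that can
/// be applied as file-skipping predicates.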
fn build_file_skipping_predicates(
target_subset_filter: Option<Expr>,
target_alias: Option<&str>,
) -> Vec<Expr> {
let Some(filter) = target_subset_filter else {
return Vec::new();
};
let filter = match target_alias {
Some(alias) => remove_table_alias(filter, alias),
None => filter,
};
split_conjunction_owned(filter)
}
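/// Compute skipped files as `total - scanned`, flagging the impossible case where more
/// files were scanned than exist.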
fn derive_skipped_file_count(total_files: usize, scanned_files: usize) -> (usize, bool) {
let impossible_state = scanned_files > total_files;
(total_files.saturating_sub(scanned_files), impossible_state)
}
fn get_metric_any(metrics: &MetricsSet, names: &[&str]) -> Option<usize> {
names
.iter()
.find_map(|name| metrics.sum_by_name(name).map(|metric| metric.as_usize()))
}
fn get_metric_any_or(
metrics: &MetricsSet,
names: &[&str],
fallback: impl FnOnce() -> usize,
) -> usize {
get_metric_any(metrics, names).unwrap_or_else(fallback)
}
fn should_remove_rewritten_file(
survivors: &barrier::BarrierSurvivorSet,
log_path: &str,
table_root: &url::Url,
) -> DeltaResult<bool> {
if survivors.contains(log_path) {
return Ok(true);
}
let full_id = normalize_path_as_file_id(log_path, table_root, "merge remove")?;
Ok(survivors.contains(full_id.as_str()))
}
impl std::future::IntoFuture for MergeBuilder {
type Output = DeltaResult<(DeltaTable, MergeMetrics)>;
type IntoFuture = BoxFuture<'static, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
let this = self;
Box::pin(async move {
let snapshot =
resolve_snapshot(&this.log_store, this.snapshot.clone(), true, None).await?;
PROTOCOL.can_write_to(&snapshot)?;
let operation_id = this.get_operation_id();
this.pre_execute(operation_id).await?;
let (state, _) = resolve_session_state(
this.state.as_deref(),
this.session_fallback_policy,
|| create_session().state(),
SessionResolveContext {
operation: "merge",
table_uri: Some(this.log_store.root_url()),
cdc: false,
},
)?;
update_datafusion_session(&state, this.log_store.as_ref(), Some(operation_id))?;
let (snapshot, metrics) = execute(
this.predicate,
this.source,
this.log_store.clone(),
snapshot,
state,
this.writer_properties,
this.commit_properties,
this.safe_cast,
this.streaming,
this.source_alias,
this.target_alias,
this.merge_schema,
this.match_operations,
this.not_match_operations,
this.not_match_source_operations,
operation_id,
this.custom_execute_handler.as_ref(),
)
.await?;
if let Some(handler) = this.custom_execute_handler {
handler.post_execute(&this.log_store, operation_id).await?;
}
Ok((
DeltaTable::new_with_state(this.log_store, DeltaTableState { snapshot }),
metrics,
))
})
}
}
#[cfg(test)]
mod tests {
use crate::DeltaTable;
use crate::TableProperty;
use crate::kernel::{Action, DataType, PrimitiveType, StructField};
use crate::operations::merge::filter::generalize_filter;
use crate::protocol::*;
use crate::writer::test_utils::datafusion::{get_data, get_data_sorted};
use crate::writer::test_utils::get_arrow_schema;
use crate::writer::test_utils::get_delta_schema;
use crate::writer::test_utils::setup_table_with_configuration;
use arrow::datatypes::Schema as ArrowSchema;
use arrow::record_batch::RecordBatch;
use arrow_schema::DataType as ArrowDataType;
use arrow_schema::Field;
use dashmap::DashSet;
use datafusion::assert_batches_sorted_eq;
use datafusion::common::{Column, ScalarValue, TableReference, ToDFSchema};
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::col;
use datafusion::logical_expr::expr::BinaryExpr;
use datafusion::logical_expr::expr::Placeholder;
use datafusion::logical_expr::lit;
use datafusion::physical_plan::collect;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::metrics::MetricBuilder;
use datafusion::prelude::*;
use delta_kernel::engine::arrow_conversion::TryIntoKernel;
use delta_kernel::schema::StructType;
use itertools::Itertools;
use pretty_assertions::assert_eq;
use regex::Regex;
use serde_json::json;
use std::ops::Neg;
use std::sync::Arc;
use url::Url;
use crate::delta_datafusion::{DataFusionMixins, PATH_COLUMN, resolve_file_column_name};
use super::MergeMetrics;
pub(crate) async fn setup_table(partitions: Option<Vec<&str>>) -> DeltaTable {
let table_schema = get_delta_schema();
let table = DeltaTable::new_in_memory()
.create()
.with_columns(table_schema.fields().cloned())
.with_partition_columns(partitions.unwrap_or_default())
.await
.unwrap();
assert_eq!(table.version(), Some(0));
table
}
#[tokio::test]
async fn test_merge_early_filter_does_not_row_filter_rewritten_files() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A"])),
Arc::new(arrow::array::Int32Array::from(vec![999])),
Arc::new(arrow::array::StringArray::from(vec!["2021-02-01"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let predicate = col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(col("source.modified")));
let (table, metrics) = table
.merge(source, predicate)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| update.update("value", col("source.value")))
.unwrap()
.await
.unwrap();
assert_eq!(metrics.num_target_files_scanned, 1);
assert_eq!(metrics.num_target_files_skipped_during_scan, 1);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 999 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_metrics_derive_skipped_files_when_scan_skip_metric_missing() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
let pre_merge_files = table.snapshot().unwrap().log_data().num_files();
assert_eq!(pre_merge_files, 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A"])),
Arc::new(arrow::array::Int32Array::from(vec![999])),
Arc::new(arrow::array::StringArray::from(vec!["2021-02-01"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let predicate = col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(col("source.modified")));
let (_table, metrics) = table
.merge(source, predicate)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| update.update("value", col("source.value")))
.unwrap()
.await
.unwrap();
assert_eq!(metrics.num_target_files_scanned, 1);
assert_eq!(
metrics.num_target_files_skipped_during_scan,
pre_merge_files.saturating_sub(metrics.num_target_files_scanned)
);
assert_eq!(metrics.num_target_files_skipped_during_scan, 1);
}
#[test]
fn test_build_file_skipping_predicates_splits_conjunctions() {
let target = TableReference::parse_str("target");
let filter = col(Column::new(Some(target.clone()), "id"))
.eq(lit("A"))
.and(col(Column::new(Some(target.clone()), "modified")).eq(lit("2021-02-01")));
let predicates = super::build_file_skipping_predicates(Some(filter), Some("target"));
assert_eq!(predicates.len(), 2);
assert_eq!(
predicates[0],
col(Column::new_unqualified("id")).eq(lit("A"))
);
assert_eq!(
predicates[1],
col(Column::new_unqualified("modified")).eq(lit("2021-02-01"))
);
}
#[test]
fn test_build_file_skipping_predicates_none_returns_empty() {
let predicates = super::build_file_skipping_predicates(None, Some("target"));
assert!(predicates.is_empty());
}
#[test]
fn test_get_metric_any_or_returns_first_matching_metric() {
let metrics = ExecutionPlanMetricsSet::new();
MetricBuilder::new(&metrics)
.global_counter("files_scanned")
.add(7);
MetricBuilder::new(&metrics)
.global_counter("count_files_scanned")
.add(3);
let value = super::get_metric_any_or(
&metrics.clone_inner(),
&["count_files_scanned", "files_scanned"],
|| 0,
);
assert_eq!(value, 3);
}
#[test]
fn test_get_metric_any_or_uses_fallback_when_missing() {
let metrics = ExecutionPlanMetricsSet::new();
let value = super::get_metric_any_or(
&metrics.clone_inner(),
&["count_files_pruned", "files_pruned"],
|| 11,
);
assert_eq!(value, 11);
}
#[test]
fn test_derive_skipped_file_count_uses_difference_when_scanned_within_total() {
let (derived, impossible_state) = super::derive_skipped_file_count(5, 3);
assert_eq!(derived, 2);
assert!(!impossible_state);
}
#[test]
fn test_derive_skipped_file_count_clamps_when_scanned_exceeds_total() {
let (derived, impossible_state) = super::derive_skipped_file_count(2, 5);
assert_eq!(derived, 0);
assert!(impossible_state);
}
#[test]
fn test_merge_remove_action_matching_normalizes_relative_paths() {
let survivors = Arc::new(DashSet::new());
survivors.insert("memory://merge-table/part-0001.parquet".to_string());
let table_root = Url::parse("memory://merge-table").unwrap();
assert!(
super::should_remove_rewritten_file(&survivors, "part-0001.parquet", &table_root,)
.unwrap()
);
}
#[test]
fn test_merge_remove_action_matching_does_not_false_positive_on_unrelated_relative_path() {
let survivors = Arc::new(DashSet::new());
survivors.insert("memory://merge-table/part-9999.parquet".to_string());
let table_root = Url::parse("memory://merge-table").unwrap();
assert!(
!super::should_remove_rewritten_file(&survivors, "part-0001.parquet", &table_root,)
.unwrap()
);
}
#[test]
fn test_merge_remove_action_matching_returns_error_when_root_cannot_join_path() {
let survivors = Arc::new(DashSet::new());
let table_root = Url::parse("mailto:owner@example.com").unwrap();
let err = super::should_remove_rewritten_file(&survivors, "part-0001.parquet", &table_root)
.expect_err("expected invalid path normalization to fail");
assert!(
err.to_string().contains("Failed to normalize"),
"unexpected error: {err}"
);
}
#[tokio::test]
async fn test_merge_rewrite_removes_old_file_and_avoids_duplicate_rows() {
let (table, source) = setup().await;
let original_paths: Vec<String> = table
.snapshot()
.unwrap()
.log_data()
.into_iter()
.map(|add| add.path().to_string())
.collect();
assert!(!original_paths.is_empty());
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert!(metrics.num_target_files_added >= 1);
assert!(metrics.num_target_files_removed >= 1);
let snapshot_bytes = table
.log_store
.read_commit_entry(2)
.await
.unwrap()
.expect("failed to get snapshot bytes");
let actions = crate::logstore::get_actions(2, &snapshot_bytes).unwrap();
let removed_paths: Vec<_> = actions
.iter()
.filter_map(|action| match action {
Action::Remove(remove) => Some(remove.path.clone()),
_ => None,
})
.collect();
assert!(!removed_paths.is_empty());
assert!(
original_paths
.iter()
.any(|path| removed_paths.contains(path)),
"Expected at least one rewritten source file to be removed",
);
let table_for_query =
DeltaTable::new_with_state(table.log_store.clone(), table.snapshot().unwrap().clone());
let ctx = SessionContext::new();
table_for_query
.update_datafusion_session(&ctx.state())
.unwrap();
ctx.register_table("test", table_for_query.table_provider().await.unwrap())
.unwrap();
let duplicate_rows = ctx
.sql(
"SELECT id, value, modified, COUNT(*) AS cnt \
FROM test \
GROUP BY id, value, modified \
HAVING cnt > 1",
)
.await
.unwrap()
.collect()
.await
.unwrap();
let duplicate_count: usize = duplicate_rows.iter().map(|batch| batch.num_rows()).sum();
assert_eq!(
duplicate_count, 0,
"Expected merge output without duplicate rows"
);
}
async fn assert_merge_encoded_partition_value_removes_original_file(
partition_value: &str,
expected_raw_encoded_segment: &str,
) {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let make_source = || {
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C"])),
Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])),
Arc::new(arrow::array::StringArray::from(vec![
partition_value,
partition_value,
partition_value,
])),
],
)
.unwrap();
ctx.read_batch(batch).unwrap()
};
let predicate = col("target.modified")
.eq(lit(partition_value))
.and(col("target.id").eq(col("source.id")));
let (table, first_metrics) = table
.merge(make_source(), predicate.clone())
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(first_metrics.num_target_rows_inserted, 3);
assert_eq!(first_metrics.num_target_files_removed, 0);
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let original_file = table
.snapshot()
.unwrap()
.log_data()
.into_iter()
.next()
.unwrap();
let original_path = original_file.path().to_string();
let original_path_raw = original_file.path_raw().to_string();
assert!(
original_path_raw.contains(expected_raw_encoded_segment),
"expected raw encoded path to contain {expected_raw_encoded_segment}, got {original_path_raw}"
);
let (table, second_metrics) = table
.merge(make_source(), predicate)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(second_metrics.num_target_rows_updated, 3);
assert_eq!(second_metrics.num_target_files_removed, 1);
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let snapshot_bytes = table
.log_store
.read_commit_entry(2)
.await
.unwrap()
.expect("failed to get snapshot bytes");
let actions = crate::logstore::get_actions(2, &snapshot_bytes).unwrap();
let removed_paths: Vec<_> = actions
.iter()
.filter_map(|action| match action {
Action::Remove(remove) => Some(remove.path.clone()),
_ => None,
})
.collect();
assert_eq!(removed_paths, vec![original_path]);
let expected = vec![
"+----+-------+------------+".to_string(),
"| id | value | modified |".to_string(),
"+----+-------+------------+".to_string(),
format!("| A | 1 | {partition_value} |"),
format!("| B | 2 | {partition_value} |"),
format!("| C | 3 | {partition_value} |"),
"+----+-------+------------+".to_string(),
];
let expected_refs: Vec<_> = expected.iter().map(String::as_str).collect();
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected_refs, &actual);
}
#[tokio::test]
async fn test_merge_partition_value_with_space_removes_original_file() {
assert_merge_encoded_partition_value_removes_original_file("2021 02 01", "%2520").await;
}
#[tokio::test]
async fn test_merge_partition_value_with_slash_removes_original_file() {
assert_merge_encoded_partition_value_removes_original_file("2021/02/01", "%252F").await;
}
#[tokio::test]
async fn test_merge_partition_value_with_percent_removes_original_file() {
assert_merge_encoded_partition_value_removes_original_file("2021%02%01", "%2525").await;
}
#[tokio::test]
async fn test_merge_when_delta_table_is_append_only() {
let schema = get_arrow_schema(&None);
let table = setup_table_with_configuration(TableProperty::AppendOnly, Some("true")).await;
let table = write_data(table, &schema).await;
let _err = table
.merge(merge_source(schema), col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_by_source_delete(|delete| delete)
.unwrap()
.await
.expect_err("Remove action is included when Delta table is append-only. Should error");
}
async fn write_data(table: DeltaTable, schema: &Arc<ArrowSchema>) -> DeltaTable {
let batch = RecordBatch::try_new(
Arc::clone(schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])),
Arc::new(arrow::array::Int32Array::from(vec![1, 10, 10, 100])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-01",
"2021-02-01",
"2021-02-02",
"2021-02-02",
])),
],
)
.unwrap();
table
.write(vec![batch.clone()])
.with_save_mode(SaveMode::Append)
.await
.unwrap()
}
async fn write_data_struct(table: DeltaTable, schema: &Arc<ArrowSchema>) -> DeltaTable {
let count_array = arrow::array::Int64Array::from(vec![Some(1), Some(2), Some(3), Some(4)]);
let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"count",
ArrowDataType::Int64,
true,
)]));
let batch = RecordBatch::try_new(
Arc::clone(schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])),
Arc::new(arrow::array::Int32Array::from(vec![1, 10, 10, 100])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-01",
"2021-02-01",
"2021-02-02",
"2021-02-02",
])),
Arc::new(arrow::array::StructArray::from(
RecordBatch::try_new(nested_schema, vec![Arc::new(count_array)]).unwrap(),
)),
],
)
.unwrap();
table
.write(vec![batch.clone()])
.with_schema_mode(crate::operations::write::SchemaMode::Overwrite)
.with_save_mode(SaveMode::Overwrite)
.await
.unwrap()
}
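/// Builds the standard three-row merge source (ids B, C, X) as a DataFrame.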
fn merge_source(schema: Arc<ArrowSchema>) -> DataFrame {
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
ctx.read_batch(batch).unwrap()
}
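/// Creates an unpartitioned table with the initial four rows and returns it together with the standard merge source.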
async fn setup() -> (DeltaTable, DataFrame) {
let schema = get_arrow_schema(&None);
let table = setup_table(None).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
(table, merge_source(schema))
}
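/// Asserts whether the latest commit on the table contains a Metadata action, i.e. whether the commit updated the table schema.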
async fn assert_latest_commit_has_metadata_action(table: &DeltaTable, expected: bool) {
let version = table.version().expect("expected merge commit version");
let snapshot_bytes = table
.log_store
.read_commit_entry(version)
.await
.unwrap()
.expect("failed to get snapshot bytes");
let actions = crate::logstore::get_actions(version, &snapshot_bytes).unwrap();
let has_metadata_action = actions
.iter()
.any(|action| matches!(action, Action::Metadata(_)));
assert_eq!(has_metadata_action, expected);
}
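/// Shared assertions for the standard merge scenario produced by `setup`: verifies the merge metrics and the resulting table contents.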
async fn assert_merge(table: DeltaTable, metrics: MergeMetrics) {
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 1);
assert!(metrics.num_target_files_added >= 1);
assert_eq!(metrics.num_target_files_removed, 1);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 3);
assert_eq!(metrics.num_target_rows_inserted, 1);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 5);
assert_eq!(metrics.num_source_rows, 3);
assert_ne!(
metrics.scan_time_ms, 0,
"Expected the scan time to be non-zero"
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 2 | 2021-02-01 |",
"| B | 10 | 2021-02-02 |",
"| C | 20 | 2023-07-04 |",
"| D | 100 | 2021-02-02 |",
"| X | 30 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[test]
fn test_normalize_target_subset_filter_coerces_decimal_literals() {
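// The filter literal has precision 4 while the column is Decimal128(6, 1); normalization
// should coerce the literal to the column's precision and scale.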
let schema = ArrowSchema::new(vec![Field::new(
"altitude",
ArrowDataType::Decimal128(6, 1),
true,
)])
.to_dfschema()
.unwrap();
let normalized = super::normalize_target_subset_filter(
Arc::new(schema),
col("altitude").eq(lit(ScalarValue::Decimal128(Some(1505), 4, 1))),
)
.unwrap();
match normalized {
Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
Expr::Literal(value, _) => {
assert_eq!(value, &ScalarValue::Decimal128(Some(1505), 6, 1));
}
other => panic!("expected decimal literal, got {other:?}"),
},
other => panic!("expected binary expr, got {other:?}"),
}
}
#[tokio::test]
async fn test_merge_with_user_path_column_namespace_collision() {
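// The table already contains user columns named PATH_COLUMN and `{PATH_COLUMN}_1`, so the
// internally generated file-path column must be disambiguated to `{PATH_COLUMN}_2`.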
let delta_schema = vec![
StructField::new(
"id".to_string(),
DataType::Primitive(PrimitiveType::String),
false,
),
StructField::new(
"value".to_string(),
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"modified".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
PATH_COLUMN.to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
format!("{PATH_COLUMN}_1"),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let table = DeltaTable::new_in_memory()
.create()
.with_columns(delta_schema)
.await
.unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, false),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new(PATH_COLUMN, ArrowDataType::Utf8, true),
Field::new(format!("{PATH_COLUMN}_1"), ArrowDataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-01",
"2021-02-02",
])),
Arc::new(arrow::array::StringArray::from(vec!["alpha", "beta"])),
Arc::new(arrow::array::StringArray::from(vec!["alpha-1", "beta-1"])),
],
)
.unwrap();
let table = table
.write(vec![batch])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
let file_column_name = resolve_file_column_name(
table.snapshot().unwrap().snapshot().input_schema().as_ref(),
None,
)
.unwrap();
assert_eq!(file_column_name, format!("{PATH_COLUMN}_2"));
let ctx = SessionContext::new();
let source_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B"])),
Arc::new(arrow::array::Int32Array::from(vec![20])),
Arc::new(arrow::array::StringArray::from(vec!["2021-03-01"])),
Arc::new(arrow::array::StringArray::from(vec!["beta-src"])),
Arc::new(arrow::array::StringArray::from(vec!["beta-src-1"])),
],
)
.unwrap();
let source = ctx.read_batch(source_batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(metrics.num_target_rows_updated, 1);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert!(table.snapshot().unwrap().log_data().num_files() >= 1);
}
#[tokio::test]
async fn test_merge_metrics_select_target_scan_when_source_is_delta_with_same_file_column_name()
{
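// The source is another Delta table registered with the same file column name as the target.
// Target scan metrics must only count the single target file, not the two source files.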
let target_dir = tempfile::tempdir().unwrap();
let target_path = std::fs::canonicalize(target_dir.path()).unwrap();
let target_url = Url::from_directory_path(&target_path).unwrap();
let target_table = DeltaTable::try_from_url(target_url)
.await
.unwrap()
.create()
.with_columns(get_delta_schema().fields().cloned())
.await
.unwrap();
let target_table = write_data(target_table, &get_arrow_schema(&None)).await;
assert_eq!(target_table.snapshot().unwrap().log_data().num_files(), 1);
let target_file_column = resolve_file_column_name(
target_table
.snapshot()
.unwrap()
.snapshot()
.input_schema()
.as_ref(),
None,
)
.unwrap();
assert_eq!(target_file_column, PATH_COLUMN);
let source_dir = tempfile::tempdir().unwrap();
let source_path = std::fs::canonicalize(source_dir.path()).unwrap();
let source_url = Url::from_directory_path(&source_path).unwrap();
let source_table = DeltaTable::try_from_url(source_url)
.await
.unwrap()
.create()
.with_columns(get_delta_schema().fields().cloned())
.await
.unwrap();
let source_schema = get_arrow_schema(&None);
let source_batch_1 = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B"])),
Arc::new(arrow::array::Int32Array::from(vec![20])),
Arc::new(arrow::array::StringArray::from(vec!["2021-03-01"])),
],
)
.unwrap();
let source_batch_2 = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["X"])),
Arc::new(arrow::array::Int32Array::from(vec![30])),
Arc::new(arrow::array::StringArray::from(vec!["2021-03-02"])),
],
)
.unwrap();
let source_table = source_table
.write(vec![source_batch_1])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
let source_table = source_table
.write(vec![source_batch_2])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
assert_eq!(source_table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
source_table
.update_datafusion_session(&ctx.state())
.unwrap();
ctx.register_table(
"source_table",
source_table
.table_provider()
.with_file_column(PATH_COLUMN)
.await
.unwrap(),
)
.unwrap();
let source = ctx
.sql("SELECT id, value, modified FROM source_table")
.await
.unwrap();
let (table, metrics) = target_table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(metrics.num_target_files_scanned, 1);
assert_eq!(metrics.num_target_files_skipped_during_scan, 0);
assert_eq!(metrics.num_target_rows_updated, 1);
assert_eq!(metrics.num_target_rows_inserted, 1);
let actual = get_data_sorted(&table, "id, value, modified").await;
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 20 | 2021-03-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"| X | 30 | 2021-03-02 |",
"+----+-------+------------+",
];
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge() {
let (table, source) = setup().await;
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"update"}]"#)
);
assert_eq!(
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"update","predicate":"target.value = 1"}]"#)
);
assert_merge(table, metrics).await;
}
#[tokio::test]
async fn test_merge_preserves_nullability_without_schema_merge() {
let delta_schema = vec![
StructField::new(
"id".to_string(),
DataType::Primitive(PrimitiveType::String),
false,
),
StructField::new(
"value".to_string(),
DataType::Primitive(PrimitiveType::Integer),
false,
),
StructField::new(
"modified".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let table = DeltaTable::new_in_memory()
.create()
.with_save_mode(SaveMode::ErrorIfExists)
.with_columns(delta_schema)
.await
.unwrap();
let initial_fields: Vec<_> = table
.snapshot()
.unwrap()
.schema()
.fields()
.cloned()
.collect();
assert!(
!initial_fields[0].is_nullable(),
"id should be non-nullable"
);
assert!(
!initial_fields[1].is_nullable(),
"value should be non-nullable"
);
assert!(
initial_fields[2].is_nullable(),
"modified should be nullable"
);
let source_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true), Field::new("value", ArrowDataType::Int32, true), Field::new("modified", ArrowDataType::Utf8, true), ]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
source_schema,
vec![
Arc::new(arrow::array::StringArray::from(vec![Some("A"), Some("B")])),
Arc::new(arrow::array::Int32Array::from(vec![Some(1), Some(2)])),
Arc::new(arrow::array::StringArray::from(vec![
Some("2021-02-02"),
None,
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (merged_table, _) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
let schema = merged_table.snapshot().unwrap().schema();
let final_fields: Vec<_> = schema.fields().collect();
assert!(
!final_fields[0].is_nullable(),
"id should remain non-nullable after merge"
);
assert!(
!final_fields[1].is_nullable(),
"value should remain non-nullable after merge"
);
assert!(
final_fields[2].is_nullable(),
"modified should remain nullable after merge"
);
}
#[tokio::test]
async fn test_merge_with_schema_merge_no_change_of_schema() {
let (table, _) = setup().await;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::LargeUtf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::LargeStringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (after_table, metrics) = table
.clone()
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
let last_commit = after_table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"update"}]"#)
);
assert_eq!(
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"update","predicate":"target.value = 1"}]"#)
);
assert_eq!(
table.snapshot().unwrap().schema(),
after_table.snapshot().unwrap().schema()
);
assert_latest_commit_has_metadata_action(&after_table, false).await;
assert_merge(after_table, metrics).await;
}
#[tokio::test]
async fn test_merge_with_schema_merge_partitioned_string_view_source() {
let schema = get_arrow_schema(&None);
let before_table = setup_table(Some(vec!["modified"])).await;
let before_table = write_data(before_table, &schema).await;
let source_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8View, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 30])),
Arc::new(arrow::array::StringViewArray::from(vec![
"2021-02-02",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (after_table, _) = before_table
.clone()
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(
before_table.snapshot().unwrap().schema(),
after_table.snapshot().unwrap().schema()
);
assert_latest_commit_has_metadata_action(&after_table, false).await;
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-02 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"| X | 30 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&after_table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_with_schema_merge_and_struct() {
let (table, _) = setup().await;
let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"count",
ArrowDataType::Int64,
true,
)]));
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new(
"nested",
ArrowDataType::Struct(nested_schema.fields().clone()),
true,
),
]));
let count_array = arrow::array::Int64Array::from(vec![Some(1)]);
let id_array = arrow::array::StringArray::from(vec![Some("X")]);
let value_array = arrow::array::Int32Array::from(vec![Some(1)]);
let modified_array = arrow::array::StringArray::from(vec![Some("2021-02-02")]);
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(id_array),
Arc::new(value_array),
Arc::new(modified_array),
Arc::new(arrow::array::StructArray::from(
RecordBatch::try_new(nested_schema, vec![Arc::new(count_array)]).unwrap(),
)),
],
)
.unwrap();
let ctx = SessionContext::new();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.clone()
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("nested", col("source.nested"))
})
.unwrap()
.await
.unwrap();
assert_latest_commit_has_metadata_action(&table, true).await;
let expected = vec![
"+----+-------+------------+------------+",
"| id | value | modified | nested |",
"+----+-------+------------+------------+",
"| A | 1 | 2021-02-01 | |",
"| B | 10 | 2021-02-01 | |",
"| C | 10 | 2021-02-02 | |",
"| D | 100 | 2021-02-02 | |",
"| X | 1 | 2021-02-02 | {count: 1} |",
"+----+-------+------------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_with_schema_merge_and_pre_existing_struct_added_column() {
let table = setup_table(None).await;
let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"count",
ArrowDataType::Int64,
true,
)]));
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new(
"nested",
ArrowDataType::Struct(nested_schema.fields().clone()),
true,
),
]));
let table_with_struct = write_data_struct(table, &schema).await;
let nested_schema_source = Arc::new(ArrowSchema::new(vec![Field::new(
"name",
ArrowDataType::Utf8,
true,
)]));
let schema_source = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new(
"nested",
ArrowDataType::Struct(nested_schema_source.fields().clone()),
true,
),
]));
let name_array = arrow::array::StringArray::from(vec![Some("John")]);
let id_array = arrow::array::StringArray::from(vec![Some("X")]);
let value_array = arrow::array::Int32Array::from(vec![Some(1)]);
let modified_array = arrow::array::StringArray::from(vec![Some("2021-02-02")]);
let batch = RecordBatch::try_new(
schema_source,
vec![
Arc::new(id_array),
Arc::new(value_array),
Arc::new(modified_array),
Arc::new(arrow::array::StructArray::from(
RecordBatch::try_new(nested_schema_source, vec![Arc::new(name_array)]).unwrap(),
)),
],
)
.unwrap();
let ctx = SessionContext::new();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table_with_struct
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("nested", col("source.nested"))
})
.unwrap()
.await
.unwrap();
assert_latest_commit_has_metadata_action(&table, true).await;
let expected = vec![
"+----+-------+------------+-----------------------+",
"| id | value | modified | nested |",
"+----+-------+------------+-----------------------+",
"| A | 1 | 2021-02-01 | {count: 1, name: } |",
"| B | 10 | 2021-02-01 | {count: 2, name: } |",
"| C | 10 | 2021-02-02 | {count: 3, name: } |",
"| D | 100 | 2021-02-02 | {count: 4, name: } |",
"| X | 1 | 2021-02-02 | {count: , name: John} |",
"+----+-------+------------+-----------------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_update() {
let (table, _) = setup().await;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_matched_update(|update| {
update
.update("value", col("source.value").add(lit(1)))
.update("modified", col("source.modified"))
.update("inserted_by", col("source.inserted_by"))
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| A | 1 | 2021-02-01 | |",
"| B | 51 | 2021-02-02 | B1 |",
"| C | 201 | 2023-07-04 | C1 |",
"| D | 100 | 2021-02-02 | |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_update_with_simple_insert() {
let (table, _) = setup().await;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_matched_update(|update| {
update
.update("value", col("source.value").add(lit(1)))
.update("modified", col("source.modified"))
.update("inserted_by", col("source.inserted_by"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("inserted_by", "source.inserted_by")
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| A | 1 | 2021-02-01 | |",
"| B | 51 | 2021-02-02 | B1 |",
"| C | 201 | 2023-07-04 | C1 |",
"| D | 100 | 2021-02-02 | |",
"| X | 30 | 2023-07-04 | X1 |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_insert_with_simple_update() {
let (table, _) = setup().await;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("inserted_by", "source.inserted_by")
})
.unwrap()
.when_matched_update(|update| {
update
.update("value", col("source.value").add(lit(1)))
.update("modified", col("source.modified"))
.update("inserted_by", col("source.inserted_by"))
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| A | 1 | 2021-02-01 | |",
"| B | 51 | 2021-02-02 | B1 |",
"| C | 201 | 2023-07-04 | C1 |",
"| D | 100 | 2021-02-02 | |",
"| X | 30 | 2023-07-04 | X1 |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_insert() {
let (table, _) = setup().await;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("inserted_by", "source.inserted_by")
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| A | 1 | 2021-02-01 | |",
"| B | 10 | 2021-02-01 | |",
"| C | 10 | 2021-02-02 | |",
"| D | 100 | 2021-02-02 | |",
"| X | 30 | 2023-07-04 | X1 |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_str() {
let (table, source) = setup().await;
let (table, metrics) = table
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("target.value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate("target.value = 1")
.update("value", "target.value + cast(1 as int)")
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("target.id", "source.id")
.set("value", "source.value")
.set("modified", "source.modified")
})
.unwrap()
.await
.unwrap();
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"update"}]"#)
);
assert_eq!(
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"update","predicate":"target.value = 1"}]"#)
);
assert_merge(table, metrics).await;
}
#[tokio::test]
async fn test_merge_no_alias() {
let (table, source) = setup().await;
let source = source
.with_column_renamed("id", "source_id")
.unwrap()
.with_column_renamed("value", "source_value")
.unwrap()
.with_column_renamed("modified", "source_modified")
.unwrap();
let (table, metrics) = table
.merge(source, "id = source_id")
.when_matched_update(|update| {
update
.update("value", "source_value")
.update("modified", "source_modified")
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update.predicate("value = 1").update("value", "value + 1")
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", "source_id")
.set("value", "source_value")
.set("modified", "source_modified")
})
.unwrap()
.await
.unwrap();
assert_merge(table, metrics).await;
}
#[tokio::test]
async fn test_merge_with_alias_mix() {
let (table, source) = setup().await;
let source = source
.with_column_renamed("id", "source_id")
.unwrap()
.with_column_renamed("value", "source_value")
.unwrap()
.with_column_renamed("modified", "source_modified")
.unwrap();
let (table, metrics) = table
.merge(source, "id = source_id")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", "source_value")
.update("modified", "source_modified")
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate("value = 1")
.update("value", "target.value + 1")
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", "source_id")
.set("target.value", "source_value")
.set("modified", "source_modified")
})
.unwrap()
.await
.unwrap();
assert_merge(table, metrics).await;
}
#[tokio::test]
async fn test_merge_failures() {
let (table, source) = setup().await;
let res = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("source.value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.await;
assert!(res.is_err());
let (table, source) = setup().await;
let res = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("source")
.when_matched_update(|update| {
update
.update("target.value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.await;
assert!(res.is_err())
}
#[tokio::test]
async fn test_merge_update_multiple_source_match_error() {
let schema = get_arrow_schema(&None);
let table = setup_table(None).await;
let table = write_data(table, &schema).await;
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-05",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let res = table
.clone()
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.await;
let err = res.expect_err("expected duplicate validation failure");
assert!(
err.to_string()
.contains("duplicate relevant WHEN MATCHED clauses")
);
assert_eq!(table.version(), Some(1));
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_update_duplicate_with_noop_source_row_passes() {
let schema = get_arrow_schema(&None);
let table = setup_table(None).await;
let table = write_data(table, &schema).await;
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-05",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.predicate(col("source.value").gt(lit(11)))
.update("value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(metrics.num_source_rows, 2);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_updated, 1);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_target_rows_copied, 3);
assert_eq!(metrics.num_output_rows, 4);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 12 | 2023-07-05 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_cdf_enabled_update_duplicate_with_noop_source_row_passes() {
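// Same duplicate-source scenario as the previous test, but with Change Data Feed enabled:
// the noop duplicate row must not trigger duplicate-match validation, and the CDF output
// must record exactly one update pre-/post-image pair.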
use crate::kernel::ProtocolInner;
use crate::operations::merge::Action;
let delta_schema = get_delta_schema();
let actions = vec![Action::Protocol(ProtocolInner::new(1, 4).as_kernel())];
let table = DeltaTable::new_in_memory()
.create()
.with_columns(delta_schema.fields().cloned())
.with_actions(actions)
.with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
.await
.unwrap();
let schema = get_arrow_schema(&None);
let table = write_data(table, &schema).await;
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-05",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.predicate(col("source.value").gt(lit(11)))
.update("value", "source.value")
.update("modified", "source.modified")
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(metrics.num_source_rows, 2);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_updated, 1);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_target_rows_copied, 3);
assert_eq!(metrics.num_output_rows, 4);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 12 | 2023-07-05 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
let cdf = table
.scan_cdf()
.with_starting_version(0)
.build(&ctx.state(), None)
.await
.expect("Failed to load CDF");
let mut batches = collect(cdf, ctx.task_ctx())
.await
.expect("Failed to collect CDF batches");
let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(5)).collect();
let expected_cdf = vec![
"+----+-------+------------+------------------+-----------------+",
"| id | value | modified | _change_type | _commit_version |",
"+----+-------+------------+------------------+-----------------+",
"| A | 1 | 2021-02-01 | insert | 1 |",
"| B | 10 | 2021-02-01 | insert | 1 |",
"| B | 10 | 2021-02-01 | update_preimage | 2 |",
"| B | 12 | 2023-07-05 | update_postimage | 2 |",
"| C | 10 | 2021-02-02 | insert | 1 |",
"| D | 100 | 2021-02-02 | insert | 1 |",
"+----+-------+------------+------------------+-----------------+",
];
assert_batches_sorted_eq!(&expected_cdf, &batches);
}
#[tokio::test]
async fn test_merge_unconditional_delete_multiple_source_match_allowed() {
let schema = get_arrow_schema(&None);
let table = setup_table(None).await;
let table = write_data(table, &schema).await;
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-05",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _) = table
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_delete(|delete| delete)
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_conditional_delete_multiple_source_match_error() {
let schema = get_arrow_schema(&None);
let table = setup_table(None).await;
let table = write_data(table, &schema).await;
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-05",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let res = table
.clone()
.merge(source, "target.id = source.id")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_delete(|delete| delete.predicate(col("source.value").gt(lit(10))))
.unwrap()
.await;
let err = res.expect_err("expected duplicate validation failure");
assert!(
err.to_string()
.contains("duplicate relevant WHEN MATCHED clauses")
);
assert_eq!(table.version(), Some(1));
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_partitions() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(
source,
col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(lit("2021-02-02"))),
)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.modified").eq(lit("2021-02-01")))
.update("value", col("target.value") - lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 3);
assert!(metrics.num_target_files_added >= 3);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 3);
assert_eq!(metrics.num_target_rows_inserted, 2);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 6);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(
parameters["mergePredicate"],
"target.id = source.id AND target.modified = '2021-02-02'"
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 2 | 2021-02-01 |",
"| B | 9 | 2021-02-01 |",
"| B | 10 | 2021-02-02 |",
"| C | 20 | 2023-07-04 |",
"| D | 100 | 2021-02-02 |",
"| X | 30 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_partition_filtered() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2021-02-02",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _metrics) = table
.merge(
source,
col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(lit("2021-02-02"))),
)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(
parameters["predicate"],
"id >= 'B' AND id <= 'C' AND modified = '2021-02-02'"
);
assert_eq!(
parameters["mergePredicate"],
"target.id = source.id AND target.modified = '2021-02-02'"
);
}
#[tokio::test]
async fn test_merge_partitions_skipping() {
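// Partitioning by `id` lets file skipping drop files for partitions the source never touches,
// so no rows are copied and only the matched partitions are rewritten.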
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["id"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 4);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![999, 999, 999])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 3);
assert_eq!(metrics.num_target_files_added, 3);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 0);
assert_eq!(metrics.num_target_rows_updated, 2);
assert_eq!(metrics.num_target_rows_inserted, 1);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 3);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
let predicate = parameters["predicate"].as_str().unwrap();
let re = Regex::new(r"^id = '(C|X|B)' OR id = '(C|X|B)' OR id = '(C|X|B)'$").unwrap();
assert!(re.is_match(predicate));
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 999 | 2023-07-04 |",
"| C | 999 | 2023-07-04 |",
"| D | 100 | 2021-02-02 |",
"| X | 999 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_partitions_with_in() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(
source,
col("target.id")
.eq(col("source.id"))
.and(col("target.id").in_list(
vec![
col("source.id"),
col("source.modified"),
col("source.value"),
],
false,
))
.and(col("target.modified").in_list(vec![lit("2021-02-02")], false)),
)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.modified").eq(lit("2021-02-01")))
.update("value", col("target.value") - lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 3);
assert!(metrics.num_target_files_added >= 3);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 3);
assert_eq!(metrics.num_target_rows_inserted, 2);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 6);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 2 | 2021-02-01 |",
"| B | 9 | 2021-02-01 |",
"| B | 10 | 2021-02-02 |",
"| C | 20 | 2023-07-04 |",
"| D | 100 | 2021-02-02 |",
"| X | 30 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_delete_matched() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_delete(|delete| delete)
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 2);
assert_eq!(metrics.num_target_files_added, 2);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 2);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 2);
assert_eq!(metrics.num_output_rows, 2);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
let extra_info = last_commit.info.clone();
assert_eq!(
extra_info["operationMetrics"],
serde_json::to_value(&metrics).unwrap()
);
assert_eq!(parameters["predicate"], "id >= 'B' AND id <= 'X'");
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"delete"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
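// Second case: a conditional delete only removes matched rows that satisfy the predicate.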
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_delete(|delete| delete.predicate(col("source.value").lt_eq(lit(10))))
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert!(table.snapshot().unwrap().log_data().num_files() >= 2);
assert_eq!(metrics.num_target_files_added, 1);
assert_eq!(metrics.num_target_files_removed, 1);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 1);
assert_eq!(metrics.num_output_rows, 1);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_delete_not_matched() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_by_source_delete(|delete| delete)
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
assert_eq!(metrics.num_target_files_added, 2);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 2);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 2);
assert_eq!(metrics.num_output_rows, 2);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
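// Second case: a conditional not-matched-by-source delete only removes unmatched rows that satisfy the predicate.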
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_by_source_delete(|delete| {
delete.predicate(col("target.modified").gt(lit("2021-02-01")))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(metrics.num_target_files_added, 1);
assert_eq!(metrics.num_target_files_removed, 1);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 1);
assert_eq!(metrics.num_output_rows, 1);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete","predicate":"target.modified > '2021-02-01'"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
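// Same two not-matched-by-source delete scenarios as above, but with
// `with_merge_schema(true)`; source and target share a schema, so the results
// should match the previous test.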
#[tokio::test]
async fn test_merge_delete_not_matched_with_schema_merge() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_not_matched_by_source_delete(|delete| delete)
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
assert_eq!(metrics.num_target_files_added, 2);
assert_eq!(metrics.num_target_files_removed, 2);
assert_eq!(metrics.num_target_rows_copied, 2);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 2);
assert_eq!(metrics.num_output_rows, 2);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert!(!parameters.contains_key("predicate"));
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 2);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_by_source_delete(|delete| {
delete.predicate(col("target.modified").gt(lit("2021-02-01")))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(2));
assert_eq!(metrics.num_target_files_added, 1);
assert_eq!(metrics.num_target_files_removed, 1);
assert_eq!(metrics.num_target_rows_copied, 1);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 0);
assert_eq!(metrics.num_target_rows_deleted, 1);
assert_eq!(metrics.num_output_rows, 1);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete","predicate":"target.modified > '2021-02-01'"}]"#)
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
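// Merges into a freshly created, empty table: every source row is inserted and
// the recorded commit predicate is the generalized id range and modified value
// derived from the source data.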
#[tokio::test]
async fn test_merge_empty_table() {
let schema = get_arrow_schema(&None);
let table = setup_table(Some(vec!["modified"])).await;
assert_eq!(table.version(), Some(0));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 0);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(
source,
col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(lit("2021-02-02"))),
)
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(1));
assert!(table.snapshot().unwrap().log_data().num_files() >= 2);
assert!(metrics.num_target_files_added >= 2);
assert_eq!(metrics.num_target_files_removed, 0);
assert_eq!(metrics.num_target_rows_copied, 0);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 3);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 3);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(
parameters["predicate"],
json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
);
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| B | 10 | 2021-02-02 |",
"| C | 20 | 2023-07-04 |",
"| X | 30 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
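// Schema evolution on an empty table: the source carries an extra
// `inserted_by` column, and the merge both inserts all rows and widens the
// table schema to match the source.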
#[tokio::test]
async fn test_empty_table_with_schema_merge() {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let table = setup_table(Some(vec!["modified"])).await;
assert_eq!(table.version(), Some(0));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 0);
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, metrics) = table
.merge(
source,
col("target.id")
.eq(col("source.id"))
.and(col("target.modified").eq(lit("2021-02-02"))),
)
.with_merge_schema(true)
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("inserted_by", col("source.inserted_by"))
})
.unwrap()
.await
.unwrap();
assert_eq!(table.version(), Some(1));
assert!(table.snapshot().unwrap().log_data().num_files() >= 2);
assert!(metrics.num_target_files_added >= 2);
assert_eq!(metrics.num_target_files_removed, 0);
assert_eq!(metrics.num_target_rows_copied, 0);
assert_eq!(metrics.num_target_rows_updated, 0);
assert_eq!(metrics.num_target_rows_inserted, 3);
assert_eq!(metrics.num_target_rows_deleted, 0);
assert_eq!(metrics.num_output_rows, 3);
assert_eq!(metrics.num_source_rows, 3);
let last_commit = table.last_commit().await.unwrap();
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(
parameters["predicate"],
json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
);
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| B | 10 | 2021-02-02 | B1 |",
"| C | 20 | 2023-07-04 | C1 |",
"| X | 30 | 2023-07-04 | X1 |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = schema.as_ref().try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
}
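// Mixed-case column names (`Id`, `vAlue`, `mOdifieD`) must resolve correctly
// in the merge predicate and in insert expressions.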
#[tokio::test]
async fn test_merge_case_sensitive() {
let schema = vec![
StructField::new(
"Id".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
"vAlue".to_string(),
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"mOdifieD".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("Id", ArrowDataType::Utf8, true),
Field::new("vAlue", ArrowDataType::Int32, true), Field::new("mOdifieD", ArrowDataType::Utf8, true),
]));
let table = DeltaTable::new_in_memory()
.create()
.with_columns(schema)
.await
.unwrap();
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let table = write_data(table, &arrow_schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let (table, _metrics) = table
.merge(source, "target.Id = source.Id")
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_insert(|insert| {
insert
.set("Id", "source.Id")
.set("vAlue", "source.vAlue + 1") .set("mOdifieD", "source.mOdifieD")
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+",
"| Id | vAlue | mOdifieD |", "+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"| D | 100 | 2021-02-02 |",
"| X | 31 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
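// generalize_filter: a source/target equality on a partition column is
// rewritten so the source side becomes a placeholder.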
#[tokio::test]
async fn test_generalize_filter_with_partitions() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.eq(col(Column::new(target.clone().into(), "id")));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
field: None,
})
.eq(col(Column::new(target.clone().into(), "id")));
assert_eq!(generalized, expected_filter);
}
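// generalize_filter: a null-safe join condition produces one placeholder per
// source-side expression, and both captured expressions are recorded.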
#[tokio::test]
async fn test_generalize_filter_with_partitions_nulls() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let source_id = col(Column::new(source.clone().into(), "id"));
let target_id = col(Column::new(target.clone().into(), "id"));
let parsed_filter = (source_id.clone().eq(target_id.clone()))
.or(source_id.clone().is_null().and(target_id.clone().is_null()));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
field: None,
})
.eq(target_id.clone())
.or(Expr::Placeholder(Placeholder {
id: "id_1".to_owned(),
field: None,
})
.and(target_id.clone().is_null()));
assert_eq!(placeholders.len(), 2);
let captured_expressions = placeholders.into_iter().map(|p| p.expr).collect_vec();
assert!(captured_expressions.contains(&source_id));
assert!(captured_expressions.contains(&source_id.is_null()));
assert_eq!(generalized, expected_filter);
}
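// generalize_filter: a non-trivial source expression (negated column) is
// replaced by a placeholder whose captured expression and alias are recorded.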
#[tokio::test]
async fn test_generalize_filter_with_partitions_captures_expression() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.neg()
.eq(col(Column::new(target.clone().into(), "id")));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
field: None,
})
.eq(col(Column::new(target.clone().into(), "id")));
assert_eq!(generalized, expected_filter);
assert_eq!(placeholders.len(), 1);
let placeholder_expr = placeholders.first().unwrap();
let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg();
assert_eq!(placeholder_expr.expr, expected_placeholder);
assert_eq!(placeholder_expr.alias, "id_0");
assert!(!placeholder_expr.is_aggregate);
}
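// generalize_filter: a literal predicate on a target column is kept verbatim
// in the generalized filter.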
#[tokio::test]
async fn test_generalize_filter_keeps_static_target_references() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.eq(col(Column::new(target.clone().into(), "id")))
.and(col(Column::new(target.clone().into(), "id")).eq(lit("C")));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
field: None,
})
.eq(col(Column::new(target.clone().into(), "id")))
.and(col(Column::new(target.clone().into(), "id")).eq(lit("C")));
assert_eq!(generalized, expected_filter);
}
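// generalize_filter: when the joined column is not a partition column, the
// target side is constrained to a min/max placeholder range instead.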
#[tokio::test]
async fn test_generalize_filter_with_dynamic_target_range_references() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.eq(col(Column::new(target.clone().into(), "id")));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["other".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter_l = Expr::Placeholder(Placeholder {
id: "id_0_min".to_owned(),
field: None,
});
let expected_filter_h = Expr::Placeholder(Placeholder {
id: "id_0_max".to_owned(),
field: None,
});
let expected_filter = col(Column::new(target.clone().into(), "id"))
.between(expected_filter_l, expected_filter_h);
assert_eq!(generalized, expected_filter);
}
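// generalize_filter: a conjunct referencing only source columns is dropped
// from the generalized filter.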
#[tokio::test]
async fn test_generalize_filter_removes_source_references() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.eq(col(Column::new(target.clone().into(), "id")))
.and(col(Column::new(source.clone().into(), "id")).eq(lit("C")));
let mut placeholders = Vec::default();
let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
false,
)
.unwrap();
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
field: None,
})
.eq(col(Column::new(target.clone().into(), "id")));
assert_eq!(generalized, expected_filter);
}
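// Merge predicate includes a non-partition target filter
// (`target.cost is null`); rows that do not satisfy it must still be copied
// through unchanged.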
#[tokio::test]
async fn test_merge_pushdowns() {
let schema = vec![
StructField::new(
"id".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
"cost".to_string(),
DataType::Primitive(PrimitiveType::Float),
true,
),
StructField::new(
"month".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("cost", ArrowDataType::Float32, true),
Field::new("month", ArrowDataType::Utf8, true),
]));
let table = DeltaTable::new_in_memory()
.create()
.with_columns(schema)
.await
.unwrap();
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Float32Array::from(vec![Some(10.15), None])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let table = table
.write(vec![batch.clone()])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Float32Array::from(vec![
Some(12.15),
Some(11.15),
])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _metrics) = table
.merge(source, "target.id = source.id and target.cost is null")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("id", "target.id")
.update("cost", "source.cost")
.update("month", "target.month")
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | cost | month |",
"+----+-------+------------+",
"| A | 10.15 | 2023-07-04 |",
"| B | 11.15 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
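// Writes two batches with a small write batch size so the single data file is
// written with multiple row groups, then merges with a `target.id >= 'C'`
// predicate; rows outside the predicate must survive the parquet pushdown
// untouched.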
#[tokio::test]
async fn test_merge_row_groups_parquet_pushdown() {
let schema = vec![
StructField::new(
"id".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
"cost".to_string(),
DataType::Primitive(PrimitiveType::Float),
true,
),
StructField::new(
"month".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("cost", ArrowDataType::Float32, true),
Field::new("month", ArrowDataType::Utf8, true),
]));
let table = DeltaTable::new_in_memory()
.create()
.with_columns(schema)
.await
.unwrap();
let ctx = SessionContext::new();
let batch1 = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Float32Array::from(vec![Some(10.15), None])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let batch2 = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["C", "D"])),
Arc::new(arrow::array::Float32Array::from(vec![
Some(11.0),
Some(12.0),
])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let table = table
.write(vec![batch1, batch2])
.with_write_batch_size(2)
.with_save_mode(SaveMode::Append)
.await
.unwrap();
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["C", "E"])),
Arc::new(arrow::array::Float32Array::from(vec![
Some(12.15),
Some(11.15),
])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _metrics) = table
.merge(source, "target.id = source.id and target.id >= 'C'")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("id", "target.id")
.update("cost", "source.cost")
.update("month", "target.month")
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", "source.id")
.set("cost", "source.cost")
.set("month", "source.month")
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | cost | month |",
"+----+-------+------------+",
"| A | 10.15 | 2023-07-04 |",
"| B | | 2023-07-04 |",
"| C | 12.15 | 2023-07-04 |",
"| D | 12.0 | 2023-07-04 |",
"| E | 11.15 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
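// Same pushdown scenario as test_merge_pushdowns, but on a table partitioned
// by `month`.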
#[tokio::test]
async fn test_merge_pushdowns_partitioned() {
let schema = vec![
StructField::new(
"id".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
StructField::new(
"cost".to_string(),
DataType::Primitive(PrimitiveType::Float),
true,
),
StructField::new(
"month".to_string(),
DataType::Primitive(PrimitiveType::String),
true,
),
];
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("cost", ArrowDataType::Float32, true),
Field::new("month", ArrowDataType::Utf8, true),
]));
let part_cols = vec!["month"];
let table = DeltaTable::new_in_memory()
.create()
.with_columns(schema)
.with_partition_columns(part_cols)
.await
.unwrap();
let ctx = SessionContext::new();
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Float32Array::from(vec![Some(10.15), None])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let table = table
.write(vec![batch.clone()])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
Arc::new(arrow::array::Float32Array::from(vec![
Some(12.15),
Some(11.15),
])),
Arc::new(arrow::array::StringArray::from(vec![
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();
let source = ctx.read_batch(batch).unwrap();
let (table, _metrics) = table
.merge(source, "target.id = source.id and target.cost is null")
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("id", "target.id")
.update("cost", "source.cost")
.update("month", "target.month")
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | cost | month |",
"+----+-------+------------+",
"| A | 10.15 | 2023-07-04 |",
"| B | 11.15 | 2023-07-04 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}
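// With change data feed disabled (the default), a merge must not write
// anything under the `_change_data` directory.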
#[tokio::test]
async fn test_merge_cdc_disabled() {
let (table, source) = setup().await;
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_merge(table.clone(), metrics).await;
if let Ok(files) = crate::logstore::tests::flatten_list_stream(
&table.object_store(),
Some(&object_store::path::Path::from("_change_data")),
)
.await
{
assert_eq!(
0,
files.len(),
"This test should not find any written CDC files! {files:#?}"
);
}
}
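// With `EnableChangeDataFeed` set, a merge records update pre/post images and
// inserts that are readable back through `scan_cdf`.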
#[tokio::test]
async fn test_merge_cdc_enabled_simple() {
use crate::kernel::ProtocolInner;
use crate::operations::merge::Action;
let schema = get_delta_schema();
let actions = vec![Action::Protocol(ProtocolInner::new(1, 4).as_kernel())];
let table = DeltaTable::new_in_memory()
.create()
.with_columns(schema.fields().cloned())
.with_actions(actions)
.with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
.await
.unwrap();
assert_eq!(table.version(), Some(0));
let schema = get_arrow_schema(&None);
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let source = merge_source(schema);
let (table, metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
})
.unwrap()
.await
.unwrap();
assert_merge(table.clone(), metrics).await;
let ctx = SessionContext::new();
let table = table
.scan_cdf()
.with_starting_version(0)
.build(&ctx.state(), None)
.await
.expect("Failed to load CDF");
let mut batches = collect(table, ctx.task_ctx())
.await
.expect("Failed to collect batches");
let _ = arrow::util::pretty::print_batches(&batches);
let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(5)).collect();
assert_batches_sorted_eq! {[
"+----+-------+------------+------------------+-----------------+",
"| id | value | modified | _change_type | _commit_version |",
"+----+-------+------------+------------------+-----------------+",
"| A | 1 | 2021-02-01 | update_preimage | 2 |",
"| A | 2 | 2021-02-01 | update_postimage | 2 |",
"| B | 10 | 2021-02-01 | update_preimage | 2 |",
"| B | 10 | 2021-02-02 | update_postimage | 2 |",
"| C | 10 | 2021-02-02 | update_preimage | 2 |",
"| C | 20 | 2023-07-04 | update_postimage | 2 |",
"| X | 30 | 2023-07-04 | insert | 2 |",
"| A | 1 | 2021-02-01 | insert | 1 |",
"| B | 10 | 2021-02-01 | insert | 1 |",
"| C | 10 | 2021-02-02 | insert | 1 |",
"| D | 100 | 2021-02-02 | insert | 1 |",
"+----+-------+------------+------------------+-----------------+",
], &batches }
}
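// CDC-enabled merge combined with schema evolution: the added `inserted_by`
// column appears both in the table schema and in the change feed output.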
#[tokio::test]
async fn test_merge_cdc_enabled_simple_with_schema_merge() {
use crate::kernel::ProtocolInner;
use crate::operations::merge::Action;
let schema = get_delta_schema();
let actions = vec![Action::Protocol(ProtocolInner::new(1, 4).as_kernel())];
let table: DeltaTable = DeltaTable::new_in_memory()
.create()
.with_columns(schema.fields().cloned())
.with_actions(actions)
.with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
.await
.unwrap();
assert_eq!(table.version(), Some(0));
let schema = get_arrow_schema(&None);
let source_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let source = merge_source(schema);
let source = source.with_column("inserted_by", lit("new_value")).unwrap();
let (table, _) = table
.merge(source.clone(), col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
.when_matched_update(|update| {
update
.update("value", col("source.value"))
.update("modified", col("source.modified"))
})
.unwrap()
.when_not_matched_by_source_update(|update| {
update
.predicate(col("target.value").eq(lit(1)))
.update("value", col("target.value") + lit(1))
})
.unwrap()
.when_not_matched_insert(|insert| {
insert
.set("id", col("source.id"))
.set("value", col("source.value"))
.set("modified", col("source.modified"))
.set("inserted_by", col("source.inserted_by"))
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+-------------+",
"| id | value | modified | inserted_by |",
"+----+-------+------------+-------------+",
"| A | 2 | 2021-02-01 | |",
"| B | 10 | 2021-02-02 | new_value |",
"| C | 20 | 2023-07-04 | new_value |",
"| D | 100 | 2021-02-02 | |",
"| X | 30 | 2023-07-04 | new_value |",
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = source_schema.try_into_kernel().unwrap();
assert_eq!(
&expected_schema_struct,
table.snapshot().unwrap().schema().as_ref()
);
assert_batches_sorted_eq!(&expected, &actual);
let ctx = SessionContext::new();
let table = table
.scan_cdf()
.with_starting_version(0)
.build(&ctx.state(), None)
.await
.expect("Failed to load CDF");
let mut batches = collect(table, ctx.task_ctx())
.await
.expect("Failed to collect batches");
let _ = arrow::util::pretty::print_batches(&batches);
let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(6)).collect();
assert_batches_sorted_eq! {[
"+----+-------+------------+-------------+------------------+-----------------+",
"| id | value | modified | inserted_by | _change_type | _commit_version |",
"+----+-------+------------+-------------+------------------+-----------------+",
"| A | 1 | 2021-02-01 | | insert | 1 |",
"| A | 1 | 2021-02-01 | | update_preimage | 2 |",
"| A | 2 | 2021-02-01 | | update_postimage | 2 |",
"| B | 10 | 2021-02-01 | | insert | 1 |",
"| B | 10 | 2021-02-01 | | update_preimage | 2 |",
"| B | 10 | 2021-02-02 | new_value | update_postimage | 2 |",
"| C | 10 | 2021-02-02 | | insert | 1 |",
"| C | 10 | 2021-02-02 | | update_preimage | 2 |",
"| C | 20 | 2023-07-04 | new_value | update_postimage | 2 |",
"| D | 100 | 2021-02-02 | | insert | 1 |",
"| X | 30 | 2023-07-04 | new_value | insert | 2 |",
"+----+-------+------------+-------------+------------------+-----------------+",
], &batches }
}
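// CDC-enabled merge that deletes rows not matched by the source: the removed
// row (`D`) shows up as a `delete` change at the new commit version.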
#[tokio::test]
async fn test_merge_cdc_enabled_delete() {
use crate::kernel::ProtocolInner;
use crate::operations::merge::Action;
let schema = get_delta_schema();
let actions = vec![Action::Protocol(ProtocolInner::new(1, 4).as_kernel())];
let table: DeltaTable = DeltaTable::new_in_memory()
.create()
.with_columns(schema.fields().cloned())
.with_actions(actions)
.with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
.await
.unwrap();
assert_eq!(table.version(), Some(0));
let schema = get_arrow_schema(&None);
let table = write_data(table, &schema).await;
assert_eq!(table.version(), Some(1));
assert_eq!(table.snapshot().unwrap().log_data().num_files(), 1);
let source = merge_source(schema);
let (table, _metrics) = table
.merge(source, col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.when_not_matched_by_source_delete(|delete| {
delete.predicate(col("target.modified").gt(lit("2021-02-01")))
})
.unwrap()
.await
.unwrap();
let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
"+----+-------+------------+",
"| A | 1 | 2021-02-01 |",
"| B | 10 | 2021-02-01 |",
"| C | 10 | 2021-02-02 |",
"+----+-------+------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
let ctx = SessionContext::new();
let table = table
.scan_cdf()
.with_starting_version(0)
.build(&ctx.state(), None)
.await
.expect("Failed to load CDF");
let mut batches = collect(table, ctx.task_ctx())
.await
.expect("Failed to collect batches");
let _ = arrow::util::pretty::print_batches(&batches);
let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(5)).collect();
assert_batches_sorted_eq! {[
"+----+-------+------------+--------------+-----------------+",
"| id | value | modified | _change_type | _commit_version |",
"+----+-------+------------+--------------+-----------------+",
"| D | 100 | 2021-02-02 | delete | 2 |",
"| A | 1 | 2021-02-01 | insert | 1 |",
"| B | 10 | 2021-02-01 | insert | 1 |",
"| C | 10 | 2021-02-02 | insert | 1 |",
"| D | 100 | 2021-02-02 | insert | 1 |",
"+----+-------+------------+--------------+-----------------+",
], &batches }
}
}