use crate::DeltaResult;
use crate::delta_datafusion::logical::{LogicalPlanBuilderExt as _, LogicalPlanExt as _};
use crate::kernel::EagerSnapshot;
use crate::table::config::TablePropertiesExt as _;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use datafusion::prelude::*;
/// Name of the metadata column Delta uses to tag CDC rows with their change
/// type (e.g. `update_preimage` / `update_postimage`).
pub const CDC_COLUMN_NAME: &str = "_change_type";
/// Tracks the table state before and after an operation so the changed rows
/// can be emitted as Change Data Feed (CDC) records.
pub(crate) struct CDCTracker {
    // Logical plan scanning the table state prior to the operation.
    pre_dataframe: LogicalPlan,
    // Logical plan scanning the table state after the operation.
    post_dataframe: LogicalPlan,
}
impl CDCTracker {
pub(crate) fn new(pre_dataframe: LogicalPlan, post_dataframe: LogicalPlan) -> Self {
Self {
pre_dataframe,
post_dataframe,
}
}
pub(crate) fn collect(self) -> DeltaResult<LogicalPlan> {
let pre_df = self.pre_dataframe;
let post_df = self.post_dataframe;
let preimage = LogicalPlanBuilder::except(pre_df.clone(), post_df.clone(), true)?;
let postimage = LogicalPlanBuilder::except(post_df, pre_df, true)?;
let preimage = preimage.into_builder().with_column(
"_change_type",
lit(ScalarValue::Utf8(Some("update_preimage".to_string()))),
)?;
let postimage = postimage
.into_builder()
.with_column(
"_change_type",
lit(ScalarValue::Utf8(Some("update_postimage".to_string()))),
)?
.build()?;
let final_df = preimage.union(postimage)?.build()?;
Ok(final_df)
}
}
/// Decide whether this table should produce Change Data Feed files.
///
/// For writer protocol version 7 the `ChangeDataFeed` writer feature is
/// required: when the feature list is present but does not contain it, CDC is
/// disabled outright. Otherwise the decision falls back to the table's
/// `enable_change_data_feed` property.
pub(crate) fn should_write_cdc(snapshot: &EagerSnapshot) -> DeltaResult<bool> {
    let protocol = snapshot.protocol();
    if protocol.min_writer_version() == 7 {
        // v7 tables must opt in via the writer feature before any table
        // property can enable CDC.
        if let Some(features) = &protocol.writer_features() {
            if !features.contains(&delta_kernel::table_features::TableFeature::ChangeDataFeed) {
                return Ok(false);
            }
        }
    }
    Ok(snapshot.table_properties().enable_change_data_feed())
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array, StructArray};
use arrow::datatypes::{DataType, Field};
use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion::assert_batches_sorted_eq;
use datafusion::datasource::{MemTable, TableProvider};
use delta_kernel::table_features::TableFeature;
use super::*;
use crate::kernel::{Action, PrimitiveType};
use crate::kernel::{DataType as DeltaDataType, ProtocolInner};
use crate::{DeltaTable, TableProperty};
#[tokio::test]
async fn test_should_write_cdc_basic_table() {
    // A freshly created table with no CDC-related configuration at all:
    // should_write_cdc must report false.
    let mut table = DeltaTable::new_in_memory()
        .create()
        .with_column(
            "value",
            DeltaDataType::Primitive(PrimitiveType::Integer),
            true,
            None,
        )
        .await
        .expect("Failed to make a table");
    // Reload so the snapshot reflects the committed state.
    table.load().await.expect("Failed to reload table");
    let result =
        should_write_cdc(table.snapshot().unwrap().snapshot()).expect("Failed to use table");
    assert!(!result, "A default table should not create CDC files");
}
#[tokio::test]
async fn test_should_write_cdc_table_with_configuration() {
    // Writer protocol v4 (below the feature-gated v7) with the
    // EnableChangeDataFeed table property set: CDC files should be produced.
    let actions = vec![Action::Protocol(ProtocolInner::new(1, 4).as_kernel())];
    let mut table: DeltaTable = DeltaTable::new_in_memory()
        .create()
        .with_column(
            "value",
            DeltaDataType::Primitive(PrimitiveType::Integer),
            true,
            None,
        )
        .with_actions(actions)
        .with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
        .await
        .expect("failed to make a version 4 table with EnableChangeDataFeed");
    // Reload so the snapshot reflects the committed state.
    table.load().await.expect("Failed to reload table");
    let result =
        should_write_cdc(table.snapshot().unwrap().snapshot()).expect("Failed to use table");
    assert!(
        result,
        "A table with the EnableChangeDataFeed should create CDC files"
    );
}
#[tokio::test]
async fn test_should_write_cdc_v7_table_no_writer_feature() {
    // Writer protocol v7 without the ChangeDataFeed writer feature:
    // should_write_cdc must return false regardless of table properties.
    let actions = vec![Action::Protocol(ProtocolInner::new(1, 7).as_kernel())];
    let mut table: DeltaTable = DeltaTable::new_in_memory()
        .create()
        .with_column(
            "value",
            DeltaDataType::Primitive(PrimitiveType::Integer),
            true,
            None,
        )
        .with_actions(actions)
        .await
        // NOTE: message previously claimed "version 4 table with
        // EnableChangeDataFeed" — this builds a v7 table with no CDF config.
        .expect("failed to make a version 7 table");
    // Reload so the snapshot reflects the committed state.
    table.load().await.expect("Failed to reload table");
    let result =
        should_write_cdc(table.snapshot().unwrap().snapshot()).expect("Failed to use table");
    assert!(
        !result,
        "A v7 table must not write CDC files unless the writer feature is set"
    );
}
#[tokio::test]
async fn test_should_write_cdc_v7_table_with_writer_feature() {
    // Writer protocol v7 with the ChangeDataFeed writer feature AND the
    // EnableChangeDataFeed property: CDC files should be produced.
    let protocol = ProtocolInner::new(1, 7)
        .append_writer_features(vec![TableFeature::ChangeDataFeed])
        .as_kernel();
    let actions = vec![Action::Protocol(protocol)];
    let mut table: DeltaTable = DeltaTable::new_in_memory()
        .create()
        .with_column(
            "value",
            DeltaDataType::Primitive(PrimitiveType::Integer),
            true,
            None,
        )
        .with_actions(actions)
        .with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true"))
        .await
        // NOTE: message previously said "version 4" — this is a v7 table.
        .expect("failed to make a version 7 table with EnableChangeDataFeed")
;
    // Reload so the snapshot reflects the committed state.
    table.load().await.expect("Failed to reload table");
    let result =
        should_write_cdc(table.snapshot().unwrap().snapshot()).expect("Failed to use table");
    assert!(
        result,
        // Previous message described the negative case; this test asserts
        // the positive one.
        "A v7 table with the ChangeDataFeed writer feature should create CDC files"
    );
}
#[tokio::test]
async fn test_sanity_check() {
    // End-to-end check of CDCTracker on a single-column table: change one
    // row (2 -> 12) and expect exactly one preimage/postimage pair.
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![Field::new(
        "value",
        DataType::Int32,
        true,
    )]));
    let batch = RecordBatch::try_new(
        // Was `Arc::clone(&schema.clone())` — a redundant double clone.
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
    )
    .unwrap();
    let table_provider: Arc<dyn TableProvider> =
        Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap());
    let source_df = ctx.read_table(table_provider).unwrap();

    // Same data with the middle row changed from 2 to 12.
    let updated_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)]))],
    )
    .unwrap();
    let table_provider_updated: Arc<dyn TableProvider> =
        Arc::new(MemTable::try_new(schema.clone(), vec![vec![updated_batch]]).unwrap());
    let updated_df = ctx.read_table(table_provider_updated).unwrap();

    let tracker = CDCTracker::new(
        source_df.into_unoptimized_plan(),
        updated_df.into_unoptimized_plan(),
    );
    match tracker.collect() {
        Ok(plan) => {
            let batches = ctx
                .execute_logical_plan(plan)
                .await
                .unwrap()
                .collect()
                .await
                .unwrap();
            let _ = arrow::util::pretty::print_batches(&batches);
            // One batch for the preimage side, one for the postimage side.
            assert_eq!(batches.len(), 2);
            assert_batches_sorted_eq! {[
                "+-------+------------------+",
                "| value | _change_type |",
                "+-------+------------------+",
                "| 2 | update_preimage |",
                "| 12 | update_postimage |",
                "+-------+------------------+",
            ], &batches }
        }
        Err(err) => {
            println!("err: {err:#?}");
            panic!("Should have never reached this assertion");
        }
    }
}
#[tokio::test]
async fn test_sanity_check_with_pure_df() {
let nested_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("lat", DataType::Int32, true),
Field::new("long", DataType::Int32, true),
]));
let schema = Arc::new(Schema::new(vec![
Field::new("value", DataType::Int32, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields.clone()),
true,
),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
let updated_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
let _ = arrow::util::pretty::print_batches(&[batch.clone()]);
let _ = arrow::util::pretty::print_batches(&[updated_batch.clone()]);
let ctx = SessionContext::new();
let before = ctx.read_batch(batch).expect("Failed to make DataFrame");
let after = ctx
.read_batch(updated_batch)
.expect("Failed to make DataFrame");
let diff = before
.except(after)
.expect("Failed to except")
.collect()
.await
.expect("Failed to diff");
assert_eq!(diff.len(), 1);
}
#[tokio::test]
async fn test_sanity_check_with_struct() {
let ctx = SessionContext::new();
let nested_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("lat", DataType::Int32, true),
Field::new("long", DataType::Int32, true),
]));
let schema = Arc::new(Schema::new(vec![
Field::new("value", DataType::Int32, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields.clone()),
true,
),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema.clone()),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
let table_provider: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap());
let source_df = ctx.read_table(table_provider).unwrap();
let updated_batch = RecordBatch::try_new(
Arc::clone(&schema.clone()),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
let table_provider_updated: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(schema.clone(), vec![vec![updated_batch]]).unwrap());
let updated_df = ctx.read_table(table_provider_updated).unwrap();
let tracker = CDCTracker::new(
source_df.into_unoptimized_plan(),
updated_df.into_unoptimized_plan(),
);
match tracker.collect() {
Ok(plan) => {
let batches = ctx
.execute_logical_plan(plan)
.await
.unwrap()
.collect()
.await
.unwrap();
let _ = arrow::util::pretty::print_batches(&batches);
assert_eq!(batches.len(), 2);
assert_batches_sorted_eq! {[
"+-------+--------------------------+------------------+",
"| value | nested | _change_type |",
"+-------+--------------------------+------------------+",
"| 12 | {id: 2, lat: 2, long: 2} | update_postimage |",
"| 2 | {id: 2, lat: 2, long: 2} | update_preimage |",
"+-------+--------------------------+------------------+",
], &batches }
}
Err(err) => {
println!("err: {err:#?}");
panic!("Should have never reached this assertion");
}
}
}
}