use crate::table::Table;
use crate::transaction::Transaction;
use anyhow::Result;
use arrow::record_batch::RecordBatch;
use datafusion::logical_expr::Expr;
use futures::stream::BoxStream;
/// What a merge clause does to a row once the clause applies.
///
/// The `(String, Expr)` pairs are presumably "column name, value expression"
/// assignments; nothing in this file interprets them, so confirm against the
/// code that consumes the clauses.
#[derive(Clone, Debug)]
pub enum MergeAction {
    /// Update a row using the given `(String, Expr)` pairs.
    Update(Vec<(String, Expr)>),
    /// Delete a row.
    Delete,
    /// Insert a row built from the given `(String, Expr)` pairs.
    Insert(Vec<(String, Expr)>),
}
/// One `WHEN ...` arm of a merge: an optional predicate plus the action to take.
#[derive(Clone, Debug)]
pub struct MergeClause {
    /// Extra predicate guarding this clause. `None` presumably means the clause
    /// applies unconditionally (not enforced anywhere in this file).
    pub condition: Option<Expr>,
    /// Action attached to this clause.
    pub action: MergeAction,
}
/// Builder that collects the pieces of a merge of a batch stream into a table.
///
/// Construct it with `MergeBuilder::new`, optionally add clauses with
/// `when_matched` / `when_not_matched`, then call `execute`.
// `on_condition` and the two clause lists are stored by the builder but are
// not read anywhere in this file, hence the blanket `allow(unused)`.
#[allow(unused)]
pub struct MergeBuilder {
    /// Table the merge is applied to.
    table: Table,
    /// Stream of incoming record batches to merge.
    source: BoxStream<'static, Result<RecordBatch>>,
    /// Join predicate supplied by the caller.
    on_condition: Expr,
    /// Clauses registered through `when_matched`, in call order.
    matched_clauses: Vec<MergeClause>,
    /// Clauses registered through `when_not_matched`, in call order.
    not_matched_clauses: Vec<MergeClause>,
}
impl MergeBuilder {
    /// Creates a builder that merges `source` into `table`.
    ///
    /// `on_condition` is stored as given; both clause lists start empty.
    pub fn new(
        table: Table,
        source: BoxStream<'static, Result<RecordBatch>>,
        on_condition: Expr,
    ) -> Self {
        Self {
            table,
            source,
            on_condition,
            matched_clauses: Vec::new(),
            not_matched_clauses: Vec::new(),
        }
    }

    /// Appends a clause to the "matched" list and returns the builder.
    ///
    /// Clauses are kept in the order they are added.
    #[must_use = "builder methods return the updated builder"]
    pub fn when_matched(mut self, condition: Option<Expr>, action: MergeAction) -> Self {
        self.matched_clauses.push(MergeClause { condition, action });
        self
    }

    /// Appends a clause to the "not matched" list and returns the builder.
    ///
    /// Clauses are kept in the order they are added.
    #[must_use = "builder methods return the updated builder"]
    pub fn when_not_matched(mut self, condition: Option<Expr>, action: MergeAction) -> Self {
        self.not_matched_clauses
            .push(MergeClause { condition, action });
        self
    }

    /// Runs the merge as an upsert and returns the transaction holding the new files.
    ///
    /// Steps, in order:
    /// 1. Drain the source stream. An empty source returns a fresh transaction
    ///    with no files added.
    /// 2. Load every data file of the table's current snapshot into memory and
    ///    register source and target as in-memory DataFusion tables.
    /// 3. Use the schema field whose `id` is `1` as the key column.
    /// 4. Write the source keys that also exist in the target to a file marked
    ///    as equality deletes and add it to the transaction.
    /// 5. Write every source batch as a new data file and add it to the
    ///    transaction.
    ///
    /// NOTE(review): `on_condition`, `matched_clauses` and `not_matched_clauses`
    /// are not consulted here; the behaviour above is applied regardless of
    /// which clauses were registered.
    ///
    /// # Errors
    ///
    /// Returns an error if the source stream yields an error, the table has no
    /// current snapshot (`"No snapshot"`), the schema has no field with id `1`
    /// (`"PK not found"`), or any DataFusion, Arrow, reader or writer call fails.
    pub async fn execute(self) -> Result<Transaction> {
        use datafusion::datasource::MemTable;
        use datafusion::logical_expr::JoinType;
        use datafusion::prelude::*;
        use futures::TryStreamExt;
        use std::sync::Arc;

        // Only the table and the source stream are used below.
        let Self { table, source, .. } = self;

        // Stops at the first stream error instead of buffering every `Result`.
        let source_batches: Vec<RecordBatch> = source.try_collect().await?;
        if source_batches.is_empty() {
            return Ok(table.new_transaction());
        }

        let ctx = SessionContext::new();

        // The batches are needed again for the data-file writes at the end, so
        // the in-memory table gets its own copy of the batch list.
        let source_schema = source_batches[0].schema();
        let source_provider = MemTable::try_new(source_schema, vec![source_batches.clone()])?;
        ctx.register_table("source", Arc::new(source_provider))?;
        let source_df = ctx.table("source").await?;

        // Materialise the current snapshot of the target table.
        let storage = table.storage.clone();
        let reader = crate::reader::TableReader::new(storage.clone());
        let snapshot = table
            .metadata
            .current_snapshot()
            .ok_or_else(|| anyhow::anyhow!("No snapshot"))?;
        let (data_files, _) = snapshot.all_files(&storage).await?;
        let mut target_batches: Vec<RecordBatch> = Vec::new();
        for file in data_files {
            let batches = reader.read_file(&file.file_path).await?;
            target_batches.extend(batches);
        }

        // Prefer the schema of the data actually read; fall back to the table
        // metadata when the snapshot holds no rows.
        let target_schema = match target_batches.first() {
            Some(batch) => batch.schema(),
            None => table.metadata.current_schema().to_arrow_schema_ref(),
        };
        let target_provider = MemTable::try_new(target_schema, vec![target_batches])?;
        ctx.register_table("target", Arc::new(target_provider))?;
        let target_df = ctx.table("target").await?;

        let schema = table.metadata.current_schema();
        let id_field = schema
            .fields
            .iter()
            .find(|f| f.id == 1)
            .ok_or_else(|| anyhow::anyhow!("PK not found"))?;
        let id_col = &id_field.name;

        // A left-semi join outputs source columns only, so the unqualified key
        // column below refers to exactly one field, and each matching source
        // row is emitted once even if the target holds the key several times.
        // An inner join would expose the key under both `source` and `target`.
        let matched_ids_df = source_df
            .join(target_df, JoinType::LeftSemi, &[id_col], &[id_col], None)?
            .select(vec![col(id_col)])?;
        let matched_batches = matched_ids_df.collect().await?;

        // Skip zero-row result batches so that "no matches" does not produce an
        // empty delete file.
        let ids_to_delete: Vec<_> = matched_batches
            .iter()
            .filter(|batch| batch.num_columns() > 0 && batch.num_rows() > 0)
            .map(|batch| batch.column(0).clone())
            .collect();

        let mut tx = table.new_transaction();

        if !ids_to_delete.is_empty() {
            let id_arrays: Vec<&dyn arrow::array::Array> =
                ids_to_delete.iter().map(|a| a.as_ref()).collect();
            let combined_ids = arrow::compute::concat(&id_arrays)?;

            // Single-column, non-nullable schema carrying just the key.
            let del_schema = Arc::new(arrow::datatypes::Schema::new(vec![
                arrow::datatypes::Field::new(
                    id_field.name.clone(),
                    id_field.field_type.to_arrow_datatype(),
                    false,
                ),
            ]));
            let delete_batch =
                arrow::record_batch::RecordBatch::try_new(del_schema.clone(), vec![combined_ids])?;

            let delete_writer = crate::writer::TableWriter::new(
                table.storage.clone(),
                table.metadata.location.clone(),
                del_schema,
            );
            let file_id = uuid::Uuid::new_v4().to_string();
            let mut delete_file = delete_writer
                .write_batch(&delete_batch, &format!("delete-merge-{}", file_id))
                .await?;
            // The writer returns a regular file entry; mark it as equality deletes.
            delete_file.content = crate::manifest::FileContent::EqualityDeletes;
            tx.add_file(delete_file);
        }

        // Every source row is written as new data, whether or not it matched.
        let data_writer = crate::writer::TableWriter::new(
            table.storage.clone(),
            table.metadata.location.clone(),
            table.metadata.current_schema().to_arrow_schema_ref(),
        );
        for batch in source_batches {
            let file_id = uuid::Uuid::new_v4().to_string();
            let data_file = data_writer.write_batch(&batch, &file_id).await?;
            tx.add_file(data_file);
        }

        Ok(tx)
    }
}