use std::{collections::HashMap, sync::Arc, time::Instant};
use async_trait::async_trait;
use datafusion::error::Result as DataFusionResult;
use datafusion::{
catalog::Session,
common::{Column, ScalarValue, ToDFSchema as _, exec_datafusion_err},
error::DataFusionError,
execution::context::SessionState,
logical_expr::{
Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, case, col, lit, when,
},
physical_plan::{ExecutionPlan, metrics::MetricBuilder},
physical_planner::{ExtensionPlanner, PhysicalPlanner},
prelude::Expr,
};
use futures::{StreamExt as _, TryStreamExt as _, future::BoxFuture, stream};
use itertools::Itertools as _;
use parquet::file::properties::WriterProperties;
use serde::Serialize;
use tracing::log::*;
use uuid::Uuid;
use super::write::WriterStatsConfig;
use super::{
CustomExecuteHandler, Operation,
write::execution::{write_execution_plan, write_execution_plan_cdc},
};
use crate::delta_datafusion::{
DeltaScanConfig, Expression, scan_files_where_matches, update_datafusion_session,
};
use crate::kernel::resolve_snapshot;
use crate::logstore::LogStoreRef;
use crate::operations::cdc::*;
use crate::protocol::DeltaOperation;
use crate::table::state::DeltaTableState;
use crate::{DeltaResult, DeltaTable, DeltaTableError};
use crate::{
delta_datafusion::{
DeltaColumn, DeltaSessionExt, SessionFallbackPolicy, SessionResolveContext, create_session,
expr::fmt_expr_to_sql,
logical::{LogicalPlanBuilderExt as _, LogicalPlanExt as _, MetricObserver},
physical::{MetricObserverExec, find_metric_node, get_metric},
resolve_session_state,
},
kernel::{
Action, EagerSnapshot,
transaction::{CommitBuilder, CommitProperties, PROTOCOL},
},
table::config::TablePropertiesExt,
};
/// Temporary column appended to the scan output to mark which rows matched
/// the update predicate: `true` for matched rows, NULL for rows that are
/// merely copied through (see `predicate_null` in `execute`). Dropped again
/// before the rewritten data is written.
pub(crate) const UPDATE_PREDICATE_COLNAME: &str = "__delta_rs_update_predicate";

#[cfg(test)]
mod tests;

/// Id of the `MetricObserver` logical node injected into the plan; matched by
/// `UpdateMetricExtensionPlanner` when lowering to the physical plan.
const UPDATE_COUNT_ID: &str = "update_source_count";
/// Metric name for rows whose values were rewritten by the update.
const UPDATE_ROW_COUNT: &str = "num_updated_rows";
/// Metric name for rows copied unchanged from the rewritten files.
const COPIED_ROW_COUNT: &str = "num_copied_rows";
/// Builder for an UPDATE operation against a Delta table.
///
/// Configure via the `with_*` methods and execute by awaiting the builder
/// (it implements `IntoFuture`).
pub struct UpdateBuilder {
    /// Rows to update; when `None`, every row is treated as matched
    /// (the predicate defaults to `lit(true)` at execution time).
    predicate: Option<Expression>,
    /// Map of column -> new-value expression applied to matched rows.
    updates: HashMap<Column, Expression>,
    /// Snapshot to operate on; resolved from the log store when `None`.
    snapshot: Option<EagerSnapshot>,
    /// Delta log store the operation reads from and commits to.
    log_store: LogStoreRef,
    /// Optional caller-provided DataFusion session.
    session: Option<Arc<dyn Session>>,
    /// Policy for falling back to a freshly created session when `session`
    /// is absent or unsuitable.
    session_fallback_policy: SessionFallbackPolicy,
    /// Parquet writer properties for the rewritten data files.
    writer_properties: Option<WriterProperties>,
    /// Extra commit metadata / commit behavior configuration.
    commit_properties: CommitProperties,
    /// NOTE(review): stored by `with_safe_cast` but not read anywhere in this
    /// file — confirm it is consumed elsewhere (or is dead configuration).
    safe_cast: bool,
    /// Optional hook invoked around commit (pre/post execute).
    custom_execute_handler: Option<Arc<dyn CustomExecuteHandler>>,
}
/// Metrics collected while executing an update; serialized into the commit's
/// `operationMetrics` app metadata.
#[derive(Default, Serialize, Debug)]
pub struct UpdateMetrics {
    /// Number of data files added by the rewrite.
    pub num_added_files: usize,
    /// Number of data files removed (the files that contained matches).
    pub num_removed_files: usize,
    /// Number of rows whose values were changed.
    pub num_updated_rows: usize,
    /// Number of rows copied through unchanged.
    pub num_copied_rows: usize,
    /// Wall-clock time of the whole `execute` call, in milliseconds.
    pub execution_time_ms: u64,
    /// Wall-clock time of the file-matching scan, in milliseconds.
    pub scan_time_ms: u64,
}
impl super::Operation for UpdateBuilder {
    /// Hand out a clone of the registered custom execute handler, if any.
    fn get_custom_execute_handler(&self) -> Option<Arc<dyn CustomExecuteHandler>> {
        self.custom_execute_handler.as_ref().map(Arc::clone)
    }

    /// Log store the update reads from and commits to.
    fn log_store(&self) -> &LogStoreRef {
        &self.log_store
    }
}
impl UpdateBuilder {
pub(crate) fn new(log_store: LogStoreRef, snapshot: Option<EagerSnapshot>) -> Self {
Self {
predicate: None,
updates: HashMap::new(),
snapshot,
log_store,
session: None,
session_fallback_policy: SessionFallbackPolicy::default(),
writer_properties: None,
commit_properties: CommitProperties::default(),
safe_cast: false,
custom_execute_handler: None,
}
}
pub fn with_predicate<E: Into<Expression>>(mut self, predicate: E) -> Self {
self.predicate = Some(predicate.into());
self
}
pub fn with_update<S: Into<DeltaColumn>, E: Into<Expression>>(
mut self,
column: S,
expression: E,
) -> Self {
self.updates.insert(column.into().into(), expression.into());
self
}
pub fn with_session_state(mut self, session: Arc<dyn Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_session_fallback_policy(mut self, policy: SessionFallbackPolicy) -> Self {
self.session_fallback_policy = policy;
self
}
pub fn with_commit_properties(mut self, commit_properties: CommitProperties) -> Self {
self.commit_properties = commit_properties;
self
}
pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self {
self.writer_properties = Some(writer_properties);
self
}
pub fn with_safe_cast(mut self, safe_cast: bool) -> Self {
self.safe_cast = safe_cast;
self
}
pub fn with_custom_execute_handler(mut self, handler: Arc<dyn CustomExecuteHandler>) -> Self {
self.custom_execute_handler = Some(handler);
self
}
}
/// Extension planner that lowers the update `MetricObserver` logical node
/// into a `MetricObserverExec` physical node (see `ExtensionPlanner` impl).
#[derive(Clone, Debug)]
pub(crate) struct UpdateMetricExtensionPlanner {}

impl UpdateMetricExtensionPlanner {
    /// Build the planner already wrapped in the `Arc` DataFusion expects.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {})
    }
}
#[async_trait]
impl ExtensionPlanner for UpdateMetricExtensionPlanner {
    /// Lower the update-specific `MetricObserver` node to a physical
    /// `MetricObserverExec`; any other extension node yields `Ok(None)` so
    /// other planners can handle it.
    async fn plan_extension(
        &self,
        _planner: &dyn PhysicalPlanner,
        node: &dyn UserDefinedLogicalNode,
        _logical_inputs: &[&LogicalPlan],
        physical_inputs: &[Arc<dyn ExecutionPlan>],
        _session_state: &SessionState,
    ) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
        let is_update_observer = node
            .as_any()
            .downcast_ref::<MetricObserver>()
            .is_some_and(|observer| observer.id == UPDATE_COUNT_ID);
        if !is_update_observer {
            return Ok(None);
        }

        let exec = MetricObserverExec::try_new(
            UPDATE_COUNT_ID.into(),
            physical_inputs,
            |batch, metrics| {
                // The predicate column holds `true` for updated rows and NULL
                // for copied rows, so `null_count` splits the batch into the
                // two counters.
                let flags = batch.column_by_name(UPDATE_PREDICATE_COLNAME).unwrap();
                let copied = flags.null_count();
                let updated = flags.len() - copied;
                MetricBuilder::new(metrics)
                    .global_counter(UPDATE_ROW_COUNT)
                    .add(updated);
                MetricBuilder::new(metrics)
                    .global_counter(COPIED_ROW_COUNT)
                    .add(copied);
            },
        )?;
        Ok(Some(exec))
    }
}
/// Core of the UPDATE operation: rewrite every file that contains at least
/// one row matching `predicate`, applying `updates` to matched rows and
/// copying the rest unchanged.
///
/// Returns the commit actions (adds for rewritten files, removes for the
/// replaced ones, plus CDC adds when enabled) together with the collected
/// [`UpdateMetrics`]. An empty action list means no file matched.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(
    skip_all,
    fields(
        operation = "update",
        version = snapshot.version(),
        table_uri = %log_store.root_url(),
    )
)]
async fn execute(
    predicate: Expr,
    updates: HashMap<Column, Expression>,
    log_store: LogStoreRef,
    snapshot: &EagerSnapshot,
    session: &dyn Session,
    writer_properties: Option<WriterProperties>,
    operation_id: Uuid,
) -> DeltaResult<(Vec<Action>, UpdateMetrics)> {
    let exec_start = Instant::now();
    let mut metrics = UpdateMetrics::default();

    let scan_config = DeltaScanConfig::new_from_session(session);
    let schema = scan_config
        .table_schema(snapshot.table_configuration())?
        .to_dfschema_ref()?;

    // Resolve each update expression against the table schema, re-keying the
    // map by plain column name for lookup during projection below.
    let updates: HashMap<_, _> = updates
        .into_iter()
        .map(|(key, expr)| expr.resolve(session, schema.clone()).map(|e| (key.name, e)))
        .try_collect()?;

    let current_metadata = snapshot.metadata();
    let table_partition_cols = current_metadata.partition_columns().to_vec();

    // Locate the files that contain at least one matching row.
    let scan_start = Instant::now();
    let maybe_scan_plan =
        scan_files_where_matches(session, snapshot, log_store.clone(), predicate).await?;
    metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis() as u64;

    // No file matched: nothing to rewrite, nothing to commit.
    let Some(files_scan) = maybe_scan_plan else {
        return Ok((vec![], metrics));
    };

    // Marker column: `true` where a row matches, NULL otherwise. NULL (rather
    // than `false`) lets the metric observer count copied rows via null_count.
    let predicate_null =
        when(files_scan.predicate.clone(), lit(true)).otherwise(lit(ScalarValue::Boolean(None)))?;
    let input = files_scan
        .scan()
        .clone()
        .into_builder()
        .with_column(UPDATE_PREDICATE_COLNAME, predicate_null)?
        .build()?;

    // Wrap the scan in a MetricObserver so the physical planner can inject
    // the row-count collection (see UpdateMetricExtensionPlanner).
    let plan_with_metrics = LogicalPlan::Extension(Extension {
        node: Arc::new(MetricObserver {
            id: UPDATE_COUNT_ID.into(),
            input,
            enable_pushdown: false,
        }),
    });

    // Projection: for columns with an update expression, apply it only where
    // the marker is `true`; all other columns (and unmatched rows) pass
    // through unchanged.
    let expressions: Vec<_> = plan_with_metrics
        .schema()
        .fields()
        .into_iter()
        .map(|field| {
            let expr = match updates.get(field.name()) {
                Some(expr) => case(col(UPDATE_PREDICATE_COLNAME))
                    .when(lit(true), expr.to_owned())
                    .otherwise(col(Column::from_name(field.name())))?
                    .alias(field.name()),
                None => col(Column::from_name(field.name())),
            };
            Ok::<_, DataFusionError>(expr)
        })
        .try_collect()?;

    let plan_updated = LogicalPlanBuilder::new(plan_with_metrics)
        // `expressions` has no other consumer, so it is moved here instead of
        // being cloned (the previous clone was redundant).
        .project(expressions)?
        .drop_columns([UPDATE_PREDICATE_COLNAME])?
        .build()?;

    let physical_plan = session.create_physical_plan(&plan_updated).await?;
    let tracker = CDCTracker::new(files_scan.scan().clone(), plan_updated);

    let writer_stats_config = WriterStatsConfig::from_config(snapshot.table_configuration());

    // Write the rewritten data files; `object_store` already returns an owned
    // handle, so no extra clone is needed.
    let mut actions = write_execution_plan(
        Some(snapshot),
        session,
        physical_plan.clone(),
        table_partition_cols.clone(),
        log_store.object_store(Some(operation_id)),
        Some(snapshot.table_properties().target_file_size()),
        None,
        writer_properties.clone(),
        writer_stats_config.clone(),
    )
    .await?;

    // Pull the row counts back out of the injected metric node.
    let err = || DeltaTableError::Generic("Unable to locate expected metric node".into());
    let update_count = find_metric_node(UPDATE_COUNT_ID, &physical_plan).ok_or_else(err)?;
    let update_count_metrics = update_count.metrics().unwrap();
    metrics.num_updated_rows = get_metric(&update_count_metrics, UPDATE_ROW_COUNT);
    metrics.num_copied_rows = get_metric(&update_count_metrics, COPIED_ROW_COUNT);

    // Emit a remove action for every file the scan selected for rewrite
    // (filtering file views down to the scan's file set by resolved URL).
    let root_url = Arc::new(snapshot.table_configuration().table_root().clone());
    let removes: Vec<_> = snapshot
        .file_views(log_store.as_ref(), Some(files_scan.delta_predicate.clone()))
        .zip(stream::iter(std::iter::repeat((
            root_url,
            Arc::new(files_scan.files_set()),
        ))))
        .map(|(f, u)| f.map(|f| (f, u)))
        .try_filter_map(|(f, (root, valid))| async move {
            let url = root
                .clone()
                .join(f.path_raw())
                .map_err(|e| exec_datafusion_err!("{e}"))?;
            let is_valid = valid.contains(url.as_ref());
            Ok(is_valid.then(|| Action::Remove(f.remove_action(true))))
        })
        .try_collect()
        .await?;

    metrics.num_added_files = actions.len();
    metrics.num_removed_files = removes.len();
    actions.extend(removes);
    metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64;

    // Best-effort CDC: a failure to collect CDC batches is logged instead of
    // failing the whole update, and an Err from should_write_cdc disables CDC.
    if let Ok(true) = should_write_cdc(snapshot) {
        match tracker.collect() {
            Ok(cdc_plan) => {
                let cdc_exec = session.create_physical_plan(&cdc_plan).await?;
                let cdc_actions = write_execution_plan_cdc(
                    Some(snapshot),
                    session,
                    cdc_exec,
                    table_partition_cols,
                    log_store.object_store(Some(operation_id)),
                    Some(snapshot.table_properties().target_file_size()),
                    None,
                    writer_properties,
                    writer_stats_config,
                )
                .await?;
                actions.extend(cdc_actions);
            }
            Err(err) => {
                error!("Failed to collect CDC batches: {err:#?}");
            }
        };
    }

    Ok((actions, metrics))
}
impl std::future::IntoFuture for UpdateBuilder {
    type Output = DeltaResult<(DeltaTable, UpdateMetrics)>;
    type IntoFuture = BoxFuture<'static, Self::Output>;

    /// Drive the configured update: resolve the snapshot, check protocol
    /// invariants, run `execute`, and commit the resulting actions.
    fn into_future(self) -> Self::IntoFuture {
        let mut this = self;
        Box::pin(async move {
            // Load (or refresh) the snapshot this update operates against.
            let snapshot =
                resolve_snapshot(&this.log_store, this.snapshot.clone(), true, None).await?;
            // Updates remove existing files, which append-only tables forbid;
            // also verify the writer protocol version is supported.
            PROTOCOL.check_append_only(&snapshot)?;
            PROTOCOL.can_write_to(&snapshot)?;
            let operation_id = this.get_operation_id();
            this.pre_execute(operation_id).await?;
            // Resolve a DataFusion session: the caller-provided one if
            // compatible, otherwise per the fallback policy.
            let (state, _) = resolve_session_state(
                this.session.as_deref(),
                this.session_fallback_policy,
                || create_session().state(),
                SessionResolveContext {
                    operation: "update",
                    table_uri: Some(this.log_store.root_url()),
                    cdc: false,
                },
            )?;
            update_datafusion_session(&state, &this.log_store, Some(operation_id))?;
            state.ensure_log_store_registered(this.log_store.as_ref())?;
            // No update expressions: nothing to do, return the table as-is.
            if this.updates.is_empty() {
                return Ok((
                    DeltaTable::new_with_state(this.log_store, DeltaTableState::new(snapshot)),
                    UpdateMetrics::default(),
                ));
            }
            // Resolve the user predicate against the table schema, if given.
            let predicate = this
                .predicate
                .map(|p| {
                    let scan_config = DeltaScanConfig::new_from_session(&state);
                    let predicate_schema = scan_config
                        .table_schema(snapshot.table_configuration())?
                        .to_dfschema_ref()?;
                    p.resolve(&state, predicate_schema)
                })
                .transpose()?;
            // No predicate means every row is matched.
            let predicate = predicate.unwrap_or(lit(true));
            let operation = DeltaOperation::Update {
                predicate: Some(fmt_expr_to_sql(&predicate)?),
            };
            let (actions, metrics) = execute(
                predicate,
                this.updates,
                this.log_store.clone(),
                &snapshot,
                &state,
                this.writer_properties,
                operation_id,
            )
            .await?;
            // No file was rewritten: skip the commit entirely.
            if actions.is_empty() {
                return Ok((
                    DeltaTable::new_with_state(this.log_store, DeltaTableState::new(snapshot)),
                    metrics,
                ));
            }
            // Record the read version and the operation metrics in the
            // commit's app metadata.
            let mut props = this.commit_properties;
            props
                .app_metadata
                .insert("readVersion".to_owned(), snapshot.version().into());
            props.app_metadata.insert(
                "operationMetrics".to_owned(),
                serde_json::to_value(&metrics)?,
            );
            let handle = this.custom_execute_handler.take();
            // Commit the actions and return the table at the new version.
            let snapshot = CommitBuilder::from(props)
                .with_actions(actions)
                .with_operation_id(operation_id)
                .with_post_commit_hook_handler(handle)
                .build(Some(&snapshot), this.log_store.clone(), operation)
                .await?
                .snapshot()
                .snapshot;
            Ok((
                DeltaTable::new_with_state(this.log_store, DeltaTableState::new(snapshot)),
                metrics,
            ))
        })
    }
}