/// Name of the synthetic column that carries the per-row merge action
/// (computed by `merge_insert_action`) through the logical/physical plan.
const MERGE_ACTION_COLUMN: &str = "__action";
pub mod inserted_rows;
use assign_action::merge_insert_action;
use inserted_rows::KeyExistenceFilter;
use super::retry::{RetryConfig, RetryExecutor, execute_with_retry};
use super::{CommitBuilder, WriteParams, write_fragments_internal};
use crate::dataset::rowids::get_row_id_index;
use crate::dataset::transaction::UpdateMode::{RewriteColumns, RewriteRows};
use crate::dataset::utils::CapturedRowIds;
use crate::{
Dataset,
datafusion::dataframe::SessionContextExt,
dataset::{
fragment::{FileFragment, FragReadConfig},
transaction::{Operation, Transaction},
write::{merge_insert::logical_plan::MergeInsertPlanner, open_writer},
},
index::DatasetIndexInternalExt,
io::exec::{
AddRowAddrExec, Planner, TakeExec, project, scalar_index::MapIndexExec, utils::ReplayExec,
},
};
use arrow_array::{
BooleanArray, RecordBatch, RecordBatchIterator, StructArray, UInt32Array, UInt64Array,
cast::AsArray, types::UInt64Type,
};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::take::take_record_batch;
use datafusion::common::NullEquality;
use datafusion::error::DataFusionError;
use datafusion::{
execution::{
context::{SessionConfig, SessionContext},
memory_pool::MemoryConsumer,
},
logical_expr::{self, Expr, Extension, JoinType, LogicalPlan},
physical_plan::{
ColumnarValue, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream,
display::DisplayableExecutionPlan,
joins::{HashJoinExec, PartitionMode},
projection::ProjectionExec,
repartition::RepartitionExec,
stream::RecordBatchStreamAdapter,
union::UnionExec,
},
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
prelude::DataFrame,
scalar::ScalarValue,
};
use datafusion_physical_expr::expressions::Column;
use futures::{
Stream, StreamExt, TryStreamExt,
stream::{self},
};
use lance_arrow::{RecordBatchExt, SchemaExt, interleave_batches};
use lance_core::datatypes::NullabilityComparison;
use lance_core::utils::address::RowAddress;
use lance_core::{
Error, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, Result,
datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions},
error::{InvalidInputSnafu, box_error},
utils::{futures::Capacity, mask::RowAddrTreeMap, tokio::get_num_compute_intensive_cpus},
};
use lance_datafusion::{
chunker::chunk_stream,
dataframe::DataFrameExt,
exec::{LanceExecutionOptions, analyze_plan, get_session_context},
utils::reader_to_stream,
};
use lance_datafusion::{
exec::{OneShotExec, execute_plan},
utils::StreamingWriteSource,
};
use lance_file::version::LanceFileVersion;
use lance_index::mem_wal::MergedGeneration;
use lance_index::{DatasetIndexExt, IndexCriteria};
use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
use log::info;
use roaring::RoaringTreemap;
use snafu::ResultExt;
use std::{
collections::{BTreeMap, HashSet},
sync::{
Arc, Mutex,
atomic::{AtomicU32, Ordering},
},
time::Duration,
};
use tokio::task::JoinSet;
use tracing::error;
mod assign_action;
mod exec;
mod logical_plan;
/// Builds a two-column schema `{source, target}`, where each column is a
/// non-nullable struct whose children mirror `schema`'s fields.
fn combined_schema(schema: &Schema) -> Schema {
    let struct_field = |name: &str| Field::new(name, DataType::Struct(schema.fields.clone()), false);
    Schema::new(vec![struct_field("source"), struct_field("target")])
}
/// Splits a flat joined batch (source columns, then target columns, then a
/// trailing row-id column) into a two-column batch of `source`/`target`
/// structs matching [`combined_schema`].
fn unzip_batch(batch: &RecordBatch, schema: &Schema) -> RecordBatch {
    let num_fields = batch.num_columns();
    // Expect: N source fields + N target fields + 1 row-id column => odd count.
    debug_assert_eq!(num_fields % 2, 1);
    let half_num_fields = num_fields / 2;
    let row_id_col = num_fields - 1;
    // First half of the columns came from the source side of the join.
    let source_arrays = batch.columns()[0..half_num_fields].to_vec();
    let source = StructArray::new(schema.fields.clone(), source_arrays, None);
    // Remaining columns (excluding the trailing row id) are the target side.
    let target_arrays = batch.columns()[half_num_fields..row_id_col].to_vec();
    let target = StructArray::new(schema.fields.clone(), target_arrays, None);
    let combined_schema = combined_schema(schema);
    // Safe to unwrap: arrays were built from `schema.fields` above, so they
    // match `combined_schema` by construction.
    RecordBatch::try_new(
        Arc::new(combined_schema),
        vec![Arc::new(source), Arc::new(target)],
    )
    .unwrap()
}
/// Renders the join-key values of row `row_idx` as `"col = value"` pairs
/// joined by commas, for use in user-facing error messages.
///
/// Columns missing from the batch are skipped; if none can be rendered a
/// placeholder string is returned.
pub fn format_key_values_on_columns(
    batch: &RecordBatch,
    row_idx: usize,
    on_columns: &[String],
) -> String {
    // Render one cell: NULL, a quoted string, the scalar's display form, or
    // (when extraction fails) the column's data type in angle brackets.
    let render_cell = |col_pos: usize| -> String {
        let column = batch.column(col_pos);
        if column.is_null(row_idx) {
            return "NULL".to_string();
        }
        match ScalarValue::try_from_array(column, row_idx) {
            Ok(scalar_value) => match &scalar_value {
                ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
                    format!("\"{}\"", s)
                }
                _ => scalar_value.to_string(),
            },
            Err(_) => format!("<{:?}>", column.data_type()),
        }
    };
    let on_values: Vec<String> = on_columns
        .iter()
        .filter_map(|col_name| {
            batch
                .schema()
                .column_with_name(col_name)
                .map(|(pos, _)| format!("{} = {}", col_name, render_cell(pos)))
        })
        .collect();
    if on_values.is_empty() {
        "<unable to extract on column values>".to_string()
    } else {
        on_values.join(", ")
    }
}
pub fn create_duplicate_row_error(
batch: &RecordBatch,
row_idx: usize,
on_columns: &[String],
) -> DataFusionError {
DataFusionError::External(Box::new(Error::invalid_input(format!(
"Ambiguous merge inserts are prohibited: multiple source rows match the same target row on ({}). \
Please ensure each target row is matched by at most one source row.",
format_key_values_on_columns(batch, row_idx, on_columns)
))))
}
/// Behavior for target rows that have no matching source row.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WhenNotMatchedBySource {
    /// Leave unmatched target rows untouched.
    Keep,
    /// Delete every target row with no source match.
    Delete,
    /// Delete unmatched target rows that also satisfy the predicate.
    DeleteIf(Expr),
}
impl WhenNotMatchedBySource {
    /// Parses and optimizes `expr` against `dataset`'s schema, returning a
    /// [`Self::DeleteIf`] with the resulting predicate.
    ///
    /// Both parse and optimize failures are surfaced as invalid-input errors.
    pub fn delete_if(dataset: &Dataset, expr: &str) -> Result<Self> {
        let planner = Planner::new(Arc::new(dataset.schema().into()));
        let parsed = planner
            .parse_filter(expr)
            .map_err(box_error)
            .context(InvalidInputSnafu {})?;
        let optimized = planner
            .optimize_expr(parsed)
            .map_err(box_error)
            .context(InvalidInputSnafu {})?;
        Ok(Self::DeleteIf(optimized))
    }
}
/// Behavior for rows present in both source and target.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WhenMatched {
    /// Replace matched target rows with the source values.
    UpdateAll,
    /// Leave matched target rows untouched.
    DoNothing,
    /// Update only rows satisfying the predicate string (parsed later, see
    /// `merge_insert_action` in `assign_action`).
    UpdateIf(String),
    /// Treat a match as an error. NOTE(review): exact failure semantics are
    /// implemented in `assign_action` — confirm there.
    Fail,
    /// Delete matched target rows.
    Delete,
}
impl WhenMatched {
    /// Creates an [`Self::UpdateIf`] behavior from a raw predicate string.
    ///
    /// The expression is stored verbatim and parsed later during plan
    /// construction, so no validation against `dataset` happens here.
    pub fn update_if(_dataset: &Dataset, expr: &str) -> Result<Self> {
        Ok(Self::UpdateIf(String::from(expr)))
    }
}
/// Behavior for source rows that have no matching target row.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum WhenNotMatched {
    /// Insert unmatched source rows into the dataset.
    InsertAll,
    /// Ignore unmatched source rows.
    DoNothing,
}
/// How duplicate join keys within the source data are handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum SourceDedupeBehavior {
    /// Fail the operation when the source contains duplicate keys.
    #[default]
    Fail,
    /// Keep the first occurrence of each key. NOTE(review): semantics
    /// inferred from the name — confirm in `assign_action`.
    FirstSeen,
}
/// Fully resolved configuration for a merge insert operation.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct MergeInsertParams {
    // Join key column names, resolved against the dataset schema.
    on: Vec<String>,
    // Behavior for rows present in both source and target.
    when_matched: WhenMatched,
    // True when unmatched source rows should be inserted.
    insert_not_matched: bool,
    // Behavior for target rows with no source match.
    delete_not_matched_by_source: WhenNotMatchedBySource,
    // Number of times to retry the whole job on commit conflicts.
    conflict_retries: u32,
    // Time budget used by the retry executor (see `execute`).
    retry_timeout: Duration,
    // MemWAL generations to mark as merged in the resulting transaction.
    merged_generations: Vec<MergedGeneration>,
    // Skip automatic cleanup after commit when true.
    skip_auto_cleanup: bool,
    // Allow using a scalar index on the join key (see join_key_as_scalar_index).
    use_index: bool,
    // Handling of duplicate keys within the source data.
    source_dedupe_behavior: SourceDedupeBehavior,
    // Optional override for commit-level retries (None = use default).
    commit_retries: Option<u32>,
}
/// An executable merge insert operation, produced by [`MergeInsertBuilder`].
#[derive(Clone)]
pub struct MergeInsertJob {
    // Target dataset the merge writes into.
    dataset: Arc<Dataset>,
    // Resolved merge configuration.
    params: MergeInsertParams,
}
/// Builder for configuring and creating a [`MergeInsertJob`].
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder {
    // Target dataset the eventual job will write into.
    dataset: Arc<Dataset>,
    // Parameters accumulated by the builder methods.
    params: MergeInsertParams,
}
impl MergeInsertBuilder {
    /// Creates a builder joining on the given columns.
    ///
    /// If `on` is empty, the dataset's unenforced primary key columns are
    /// used instead; it is an error if neither is available. Explicit column
    /// names are resolved case-insensitively against the dataset schema.
    pub fn try_new(dataset: Arc<Dataset>, on: Vec<String>) -> Result<Self> {
        let resolved_on = if on.is_empty() {
            let schema = dataset.schema();
            let pk_fields = schema.unenforced_primary_key();
            if pk_fields.is_empty() {
                return Err(Error::invalid_input(
                    "A merge insert operation requires join keys: specify `on` columns explicitly or configure a primary key in the dataset schema",
                ));
            }
            pk_fields
                .iter()
                .map(|field| schema.field_path(field.id))
                .collect::<Result<Vec<_>>>()?
        } else {
            on.iter()
                .map(|col| {
                    dataset
                        .schema()
                        .field_case_insensitive(col)
                        .map(|f| f.name.clone())
                        .ok_or_else(|| {
                            Error::invalid_input(format!(
                                "Merge insert key column '{}' does not exist in schema",
                                col
                            ))
                        })
                })
                .collect::<Result<Vec<_>>>()?
        };
        Ok(Self {
            dataset,
            // Defaults: insert unmatched source rows, leave everything else
            // untouched, retry up to 10 times within 30 seconds on conflict.
            params: MergeInsertParams {
                on: resolved_on,
                when_matched: WhenMatched::DoNothing,
                insert_not_matched: true,
                delete_not_matched_by_source: WhenNotMatchedBySource::Keep,
                conflict_retries: 10,
                retry_timeout: Duration::from_secs(30),
                merged_generations: Vec::new(),
                skip_auto_cleanup: false,
                use_index: true,
                source_dedupe_behavior: SourceDedupeBehavior::Fail,
                commit_retries: None,
            },
        })
    }
    /// Sets the behavior for rows present in both source and target.
    pub fn when_matched(&mut self, behavior: WhenMatched) -> &mut Self {
        self.params.when_matched = behavior;
        self
    }
    /// Sets whether unmatched source rows are inserted.
    pub fn when_not_matched(&mut self, behavior: WhenNotMatched) -> &mut Self {
        self.params.insert_not_matched = match behavior {
            WhenNotMatched::DoNothing => false,
            WhenNotMatched::InsertAll => true,
        };
        self
    }
    /// Sets the behavior for target rows with no source match.
    pub fn when_not_matched_by_source(&mut self, behavior: WhenNotMatchedBySource) -> &mut Self {
        self.params.delete_not_matched_by_source = behavior;
        self
    }
    /// Sets how many times to retry on commit conflicts.
    pub fn conflict_retries(&mut self, retries: u32) -> &mut Self {
        self.params.conflict_retries = retries;
        self
    }
    /// Sets the time budget used by the retry executor.
    pub fn retry_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.params.retry_timeout = timeout;
        self
    }
    /// Sets whether automatic cleanup after commit is skipped.
    pub fn skip_auto_cleanup(&mut self, skip: bool) -> &mut Self {
        self.params.skip_auto_cleanup = skip;
        self
    }
    /// Enables or disables use of a scalar index on the join key.
    pub fn use_index(&mut self, use_index: bool) -> &mut Self {
        self.params.use_index = use_index;
        self
    }
    /// Sets how duplicate join keys in the source data are handled.
    pub fn source_dedupe_behavior(&mut self, behavior: SourceDedupeBehavior) -> &mut Self {
        self.params.source_dedupe_behavior = behavior;
        self
    }
    /// Records MemWAL generations to be marked as merged by this operation.
    pub fn mark_generations_as_merged(&mut self, generations: Vec<MergedGeneration>) -> &mut Self {
        self.params.merged_generations.extend(generations);
        self
    }
    /// Overrides the number of commit-level retries.
    pub fn commit_retries(&mut self, retries: u32) -> &mut Self {
        self.params.commit_retries = Some(retries);
        self
    }
    /// Builds the job, rejecting configurations that could not modify any
    /// data (no insert, no matched action, nothing deleted by source).
    pub fn try_build(&mut self) -> Result<MergeInsertJob> {
        if !self.params.insert_not_matched
            && self.params.when_matched == WhenMatched::DoNothing
            && self.params.delete_not_matched_by_source == WhenNotMatchedBySource::Keep
        {
            return Err(Error::invalid_input(
                "The merge insert job is not configured to change the data in any way",
            ));
        }
        Ok(MergeInsertJob {
            dataset: self.dataset.clone(),
            params: self.params.clone(),
        })
    }
}
/// Result of comparing a source schema against the dataset schema.
enum SchemaComparison {
    /// Source covers the full dataset schema.
    FullCompatible,
    /// Source is a (possibly reordered) subset of the dataset schema.
    Subschema,
}
impl MergeInsertJob {
/// Runs the merge insert against data read from `source`.
///
/// Convenience wrapper that converts the source into a record batch stream
/// and delegates to [`Self::execute`].
pub async fn execute_reader(
    self,
    source: impl StreamingWriteSource,
) -> Result<(Arc<Dataset>, MergeStats)> {
    let batch_stream = source.into_stream();
    self.execute(batch_stream).await
}
/// Classifies `schema` as fully compatible with the dataset schema or as a
/// (possibly reordered) subschema; errors if it is neither.
fn check_compatible_schema(&self, schema: &Schema) -> Result<SchemaComparison> {
    let lance_schema: lance_core::datatypes::Schema = schema.try_into()?;
    let target_schema = self.dataset.schema();
    // Dictionary encoding only matters for the legacy storage format.
    let mut options = SchemaCompareOptions {
        compare_dictionary: self.dataset.is_legacy_storage(),
        compare_nullability: NullabilityComparison::Ignore,
        ..Default::default()
    };
    let full_match = lance_schema
        .check_compatible(target_schema, &options)
        .is_ok();
    if full_match {
        Ok(SchemaComparison::FullCompatible)
    } else {
        // Retry, accepting a reordered subset of the target columns.
        options.allow_subschema = true;
        options.ignore_field_order = true;
        lance_schema
            .check_compatible(target_schema, &options)
            .map(|_| SchemaComparison::Subschema)
    }
}
/// Returns a scalar index usable for the join key, if one exists.
///
/// Only applicable when joining on exactly one column; the index must
/// support exact-equality lookups.
async fn join_key_as_scalar_index(&self) -> Result<Option<IndexMetadata>> {
    let [key_column] = self.params.on.as_slice() else {
        // Multi-column (or empty) join keys cannot use a single-column index.
        return Ok(None);
    };
    self.dataset
        .load_scalar_index(
            IndexCriteria::default()
                .for_column(key_column)
                .supports_exact_equality(),
        )
        .await
}
/// Joins `source` against the target table using a scalar index on the
/// single join key, avoiding a full table scan.
///
/// The indexed lookup is unioned with a direct scan of any fragments not
/// covered by the index; target columns are prefixed with `target_` before
/// a full outer hash join against the (replayable) source stream.
///
/// Fix: `HashJoinExec::try_new` failures are now propagated with `?`
/// instead of panicking via `unwrap()`.
async fn create_indexed_scan_joined_stream(
    &self,
    source: SendableRecordBatchStream,
    index: IndexMetadata,
) -> Result<SendableRecordBatchStream> {
    let schema = source.schema();
    // A subschema source needs row addresses so that partial column updates
    // can be applied to the right rows later.
    let add_row_addr = match self.check_compatible_schema(&schema)? {
        SchemaComparison::FullCompatible => false,
        SchemaComparison::Subschema => true,
    };
    // The source is consumed twice (index lookup + join probe side), so wrap
    // it in a replayable exec.
    let input = Arc::new(OneShotExec::new(source));
    let shared_input = Arc::new(ReplayExec::new(Capacity::Unbounded, input));
    let field = schema.field_with_name(&self.params.on[0])?;
    // Feed only the join key column into the index mapper.
    let index_mapper_input = Arc::new(project(
        shared_input.clone(),
        &Schema::new(vec![field.clone()]),
    )?);
    let index_column = self.params.on[0].clone();
    let mut index_mapper: Arc<dyn ExecutionPlan> = Arc::new(MapIndexExec::new(
        self.dataset.clone(),
        index_column.clone(),
        index.name.clone(),
        index_mapper_input,
    ));
    if add_row_addr {
        // Append _rowaddr as the last column of the mapper output.
        let pos = index_mapper.schema().fields().len();
        index_mapper = Arc::new(AddRowAddrExec::try_new(
            index_mapper,
            self.dataset.clone(),
            pos,
        )?);
    }
    let projection = self
        .dataset
        .empty_projection()
        .union_arrow_schema(schema.as_ref(), OnMissing::Error)?;
    // NOTE: the unwrap is on the Option returned by TakeExec::try_new; the
    // projection is non-empty here so it is always Some.
    let mut target =
        Arc::new(TakeExec::try_new(self.dataset.clone(), index_mapper, projection)?.unwrap())
            as Arc<dyn ExecutionPlan>;
    // Reorder columns so _rowid (and optionally _rowaddr) come last.
    let schema = target.schema();
    let mut columns = schema
        .fields()
        .iter()
        .filter(|f| f.name() != ROW_ID && f.name() != ROW_ADDR)
        .cloned()
        .collect::<Vec<_>>();
    columns.push(Arc::new(ROW_ID_FIELD.clone()));
    if add_row_addr {
        columns.push(Arc::new(ROW_ADDR_FIELD.clone()));
    }
    target = Arc::new(project(target, &Schema::new(columns))?);
    let column_names = schema
        .field_names()
        .into_iter()
        .filter(|name| name.as_str() != ROW_ID && name.as_str() != ROW_ADDR)
        .collect::<Vec<_>>();
    // Fragments created after the index was built must be scanned directly
    // and unioned with the indexed lookup.
    let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?;
    if !unindexed_fragments.is_empty() {
        let mut builder = self.dataset.scan();
        if add_row_addr {
            builder.with_row_address();
        }
        let unindexed_data = builder
            .with_row_id()
            .with_fragments(unindexed_fragments)
            .project(&column_names)
            .unwrap()
            .create_plan()
            .await?;
        let unioned = UnionExec::try_new(vec![target, unindexed_data])?;
        // Collapse back to a single partition for the join below.
        target = Arc::new(RepartitionExec::try_new(
            unioned,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
        )?);
    }
    target = Self::prefix_columns_phys(target, "target_");
    let source_key = Column::new_with_schema(&index_column, shared_input.schema().as_ref())?;
    let target_key = Column::new_with_schema(
        &format!("target_{}", index_column),
        target.schema().as_ref(),
    )?;
    // Full outer join keeps unmatched rows from both sides; NULL join keys
    // compare equal so null-keyed rows can still match.
    let joined = Arc::new(HashJoinExec::try_new(
        shared_input,
        target,
        vec![(Arc::new(source_key), Arc::new(target_key))],
        None,
        &JoinType::Full,
        None,
        PartitionMode::CollectLeft,
        NullEquality::NullEqualsNull,
    )?);
    execute_plan(
        joined,
        LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        },
    )
}
/// Renames every column of `df` to `<prefix><name>`.
///
/// Names are quoted so columns containing special characters still resolve.
fn prefix_columns(df: DataFrame, prefix: &str) -> DataFrame {
    let aliased = df
        .schema()
        .fields()
        .iter()
        .map(|field| {
            let name = field.name();
            logical_expr::col(format!("\"{}\"", name)).alias(format!("{}{}", prefix, name))
        })
        .collect::<Vec<_>>();
    // Selecting existing columns with new aliases cannot fail.
    df.select(aliased).unwrap()
}
/// Physical-plan counterpart of [`Self::prefix_columns`]: wraps `inp` in a
/// projection that renames every column to `<prefix><name>`.
fn prefix_columns_phys(inp: Arc<dyn ExecutionPlan>, prefix: &str) -> Arc<dyn ExecutionPlan> {
    let input_schema = inp.schema();
    let mut exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
        Vec::with_capacity(input_schema.fields().len());
    for (idx, field) in input_schema.fields().iter().enumerate() {
        let col_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new(field.name(), idx));
        exprs.push((col_expr, format!("{}{}", prefix, field.name())));
    }
    // Projecting existing columns by index cannot fail.
    Arc::new(ProjectionExec::try_new(exprs, inp).unwrap())
}
/// Joins `source` against a full scan of the target table.
///
/// Target-side columns are prefixed with `target_` to avoid name clashes.
/// A fully compatible source uses a full outer join; a subschema source uses
/// a left (or inner, when not inserting) join over the projected columns.
async fn create_full_table_joined_stream(
    &self,
    source: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
    // A single partition keeps downstream merge logic simple.
    let session_ctx =
        SessionContext::new_with_config(SessionConfig::default().with_target_partitions(1));
    let schema = source.schema();
    let new_data = session_ctx.read_one_shot(source)?;
    let join_cols = self
        .params
        .on
        .iter()
        .map(|c| c.as_str())
        .collect::<Vec<_>>();
    let prefixed_keys: Vec<String> = self
        .params
        .on
        .iter()
        .map(|c| format!("target_{}", c))
        .collect();
    let target_cols: Vec<&str> = prefixed_keys.iter().map(|s| s.as_str()).collect();
    match self.check_compatible_schema(&schema)? {
        SchemaComparison::FullCompatible => {
            let existing = session_ctx.read_lance(self.dataset.clone(), true, false)?;
            let existing = Self::prefix_columns(existing, "target_");
            // Full outer join: every row of both sides participates.
            let joined =
                new_data.join(existing, JoinType::Full, &join_cols, &target_cols, None)?;
            Ok(joined.execute_stream().await?)
        }
        SchemaComparison::Subschema => {
            let existing = session_ctx.read_lance(self.dataset.clone(), true, true)?;
            // Only read the source's columns (plus row id/addr) from the target.
            let columns = schema
                .field_names()
                .iter()
                .map(|s| s.as_str())
                .chain([ROW_ID, ROW_ADDR])
                .collect::<Vec<_>>();
            let projected = existing.select_columns(&columns)?;
            let projected = Self::prefix_columns(projected, "target_");
            // Unmatched target rows are irrelevant for a subschema merge, so
            // left (when inserting) or inner suffices.
            let join_type = if self.params.insert_not_matched {
                JoinType::Left
            } else {
                JoinType::Inner
            };
            let joined = new_data.join(projected, join_type, &join_cols, &target_cols, None)?;
            Ok(joined.execute_stream().await?)
        }
    }
}
/// Joins `source` with the target table, preferring an indexed lookup when
/// the configuration allows it and a usable scalar index exists.
async fn create_joined_stream(
    &self,
    source: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
    // A scalar index can only be used when no target rows may be deleted due
    // to missing source matches: that path needs to see every target row.
    let scalar_index_allowed = matches!(
        self.params.delete_not_matched_by_source,
        WhenNotMatchedBySource::Keep
    ) && self.params.use_index;
    if !scalar_index_allowed {
        info!(
            "The merge insert operation is configured to delete rows from the target table, this requires a potentially costly full table scan"
        );
        return self.create_full_table_joined_stream(source).await;
    }
    match self.join_key_as_scalar_index().await? {
        Some(index) => self.create_indexed_scan_joined_stream(source, index).await,
        None => self.create_full_table_joined_stream(source).await,
    }
}
/// Applies partial-column ("subschema") merge results to the dataset.
///
/// Consumes a stream of updated rows carrying `_rowaddr`, groups them by
/// fragment id, and rewrites the touched columns of each affected fragment
/// concurrently. Rows with a null fragment id (no target match) are written
/// as brand-new fragments. Returns
/// `(updated_fragments, new_fragments, modified_field_ids)`.
async fn update_fragments(
    dataset: Arc<Dataset>,
    source: SendableRecordBatchStream,
    current_version: u64,
) -> Result<(Vec<Fragment>, Vec<Fragment>, Vec<u32>)> {
    use datafusion::logical_expr::{col, lit};
    // Plan: derive the fragment id (upper 32 bits of the row address), sort
    // by row address, then stream row groups one fragment at a time.
    let session_ctx = get_session_context(&LanceExecutionOptions {
        use_spilling: true,
        target_partition: Some(get_num_compute_intensive_cpus().min(8)),
        ..Default::default()
    });
    let mut group_stream = session_ctx
        .read_one_shot(source)?
        .with_column("_fragment_id", col(ROW_ADDR) >> lit(32))?
        .sort(vec![col(ROW_ADDR).sort(true, true)])?
        .group_by_stream(&["_fragment_id"])
        .await?;
    // Fragment rewrites run concurrently, bounded by the store's IO
    // parallelism and a memory-pool reservation on the buffered batches.
    let updated_fragments = Arc::new(Mutex::new(Vec::new()));
    let new_fragments = Arc::new(Mutex::new(Vec::new()));
    let mut tasks = JoinSet::new();
    let task_limit = dataset.object_store().io_parallelism();
    let mut reservation =
        MemoryConsumer::new("MergeInsert").register(session_ctx.task_ctx().memory_pool());
    while let Some((frag_id, batches)) = group_stream.next().await.transpose()? {
        // Rewrites the touched columns of one existing fragment.
        async fn handle_fragment(
            dataset: Arc<Dataset>,
            fragment: FileFragment,
            mut metadata: Fragment,
            mut batches: Vec<RecordBatch>,
            updated_fragments: Arc<Mutex<Vec<Fragment>>>,
            reservation_size: usize,
            current_version: u64,
        ) -> Result<usize> {
            // Schema of the columns being rewritten (row id/addr stripped).
            let write_schema = batches[0]
                .schema()
                .as_ref()
                .without_column(ROW_ADDR)
                .without_column(ROW_ID);
            let write_schema = dataset.schema().project_by_schema(
                &write_schema,
                OnMissing::Error,
                OnTypeMismatch::Error,
            )?;
            let updated_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
            if Some(updated_rows) == metadata.physical_rows {
                // Every row of the fragment was updated: write a whole new
                // column file and append it to the fragment's file list.
                let data_storage_version = dataset
                    .manifest()
                    .data_storage_format
                    .lance_file_version()?;
                let mut writer = open_writer(
                    &dataset.object_store,
                    &write_schema,
                    &dataset.base,
                    data_storage_version,
                )
                .await?;
                batches
                    .iter_mut()
                    .try_for_each(|batch| match batch.drop_column(ROW_ADDR) {
                        Ok(b) => {
                            *batch = b;
                            Ok(())
                        }
                        Err(e) => Err(e),
                    })?;
                if data_storage_version == LanceFileVersion::Legacy {
                    // Legacy files need batches aligned with the fragment's
                    // original batch size, so re-chunk before writing.
                    let reader = fragment
                        .open(
                            dataset.schema(),
                            FragReadConfig::default().with_row_address(true),
                        )
                        .await?;
                    let batch_size = reader.legacy_num_rows_in_batch(0).unwrap();
                    let stream = stream::iter(batches.into_iter().map(Ok));
                    let stream = Box::pin(RecordBatchStreamAdapter::new(
                        Arc::new((&write_schema).into()),
                        stream,
                    ));
                    let mut stream = chunk_stream(stream, batch_size as usize);
                    while let Some(chunk) = stream.next().await {
                        writer.write(&chunk?).await?;
                    }
                } else {
                    writer.write(batches.as_slice()).await?;
                }
                let (_num_rows, data_file) = writer.finish().await?;
                metadata.files.push(data_file);
                if dataset.manifest.uses_stable_row_ids() {
                    // All rows were rewritten; refresh per-row update metadata.
                    lance_table::rowids::version::refresh_row_latest_update_meta_for_full_frag_rewrite_cols(
                        &mut metadata,
                        current_version,
                    )?;
                }
                updated_fragments.lock().unwrap().push(metadata);
            } else {
                // Only some rows changed: drive the fragment updater, merging
                // updated rows with the originals by row address.
                let update_schema = batches[0].schema();
                let read_columns = update_schema.field_names();
                let mut updater = fragment
                    .updater(
                        Some(&read_columns),
                        Some((write_schema, dataset.schema().clone())),
                        None,
                    )
                    .await?;
                // source_batches[0] is a placeholder slot for the current
                // original batch (filled each loop iteration below);
                // [1..] hold the updated rows with _rowaddr dropped.
                let mut source_batches = Vec::with_capacity(batches.len() + 1);
                source_batches.push(batches[0].clone());
                for batch in &batches {
                    source_batches.push(batch.drop_column(ROW_ADDR)?);
                }
                // Yields (row_addr, (batch_idx, offset)) across the updated
                // batches; batch_idx is offset by 1 to skip the placeholder.
                fn get_row_addr_iter(
                    batches: &[RecordBatch],
                ) -> impl Iterator<Item = (u64, (usize, usize))> + '_ + Send
                {
                    batches.iter().enumerate().flat_map(|(batch_idx, batch)| {
                        let batch_idx = batch_idx + 1;
                        let row_addrs = batch
                            .column_by_name(ROW_ADDR)
                            .unwrap()
                            .as_any()
                            .downcast_ref::<UInt64Array>()
                            .unwrap();
                        row_addrs
                            .values()
                            .iter()
                            .enumerate()
                            .map(move |(offset, row_addr)| (*row_addr, (batch_idx, offset)))
                    })
                }
                let mut updated_row_addr_iter = get_row_addr_iter(&batches).peekable();
                while let Some(batch) = updater.next().await? {
                    source_batches[0] =
                        batch.project_by_schema(source_batches[1].schema().as_ref())?;
                    let original_row_addrs = batch
                        .column_by_name(ROW_ADDR)
                        .unwrap()
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    // Both sides are sorted by row address, so a single merge
                    // pass picks the updated row when addresses match and the
                    // original row otherwise.
                    let indices = original_row_addrs
                        .values()
                        .into_iter()
                        .enumerate()
                        .map(|(original_offset, row_addr)| {
                            match updated_row_addr_iter.peek() {
                                Some((updated_row_addr, _))
                                    if *updated_row_addr == *row_addr =>
                                {
                                    updated_row_addr_iter.next().unwrap().1
                                }
                                Some((updated_row_addr, _)) => {
                                    debug_assert!(
                                        *updated_row_addr > *row_addr,
                                        "Got updated row address that is not in the original batch"
                                    );
                                    (0, original_offset)
                                }
                                _ => (0, original_offset),
                            }
                        })
                        .collect::<Vec<_>>();
                    let updated_batch = interleave_batches(&source_batches, &indices)?;
                    updater.update(updated_batch).await?;
                }
                let mut updated_fragment = updater.finish().await?;
                if dataset.manifest.uses_stable_row_ids() {
                    // Collect the deduped, sorted offsets of updated rows so
                    // their latest-update metadata can be refreshed.
                    let mut updated_offsets: Vec<usize> = Vec::new();
                    for b in batches.iter() {
                        let row_addrs = b
                            .column_by_name(ROW_ADDR)
                            .unwrap()
                            .as_any()
                            .downcast_ref::<UInt64Array>()
                            .unwrap();
                        updated_offsets.extend(
                            row_addrs
                                .values()
                                .iter()
                                .map(|addr| RowAddress::from(*addr).row_offset() as usize),
                        );
                    }
                    updated_offsets.sort_unstable();
                    updated_offsets.dedup();
                    lance_table::rowids::version::refresh_row_latest_update_meta_for_partial_frag_rewrite_cols(
                        &mut updated_fragment,
                        &updated_offsets,
                        current_version,
                        dataset.manifest.version,
                    )?;
                }
                updated_fragments.lock().unwrap().push(updated_fragment);
            }
            Ok(reservation_size)
        }
        // Writes rows with no target match (new inserts) as new fragments.
        async fn handle_new_fragments(
            dataset: Arc<Dataset>,
            batches: Vec<RecordBatch>,
            new_fragments: Arc<Mutex<Vec<Fragment>>>,
            reservation_size: usize,
        ) -> Result<usize> {
            // Drop the _rowaddr column before writing.
            let num_fields = batches[0].schema().fields().len();
            let mut projection = Vec::with_capacity(num_fields - 1);
            for (i, field) in batches[0].schema().fields().iter().enumerate() {
                if field.name() != ROW_ADDR {
                    projection.push(i);
                }
            }
            let write_schema = Arc::new(batches[0].schema().project(&projection).unwrap());
            let batches = batches
                .into_iter()
                .map(move |batch| batch.project(&projection));
            let reader = RecordBatchIterator::new(batches, write_schema.clone());
            let stream = reader_to_stream(Box::new(reader));
            let write_schema = dataset.schema().project_by_schema(
                write_schema.as_ref(),
                OnMissing::Error,
                OnTypeMismatch::Error,
            )?;
            let (fragments, _) = write_fragments_internal(
                Some(dataset.as_ref()),
                dataset.object_store.clone(),
                &dataset.base,
                write_schema,
                stream,
                Default::default(),
                None,
            )
            .await?;
            new_fragments.lock().unwrap().extend(fragments);
            Ok(reservation_size)
        }
        // Backpressure: wait for a task slot AND a memory reservation before
        // spawning; if no tasks are running but memory is unavailable,
        // proceed without a reservation rather than deadlocking.
        let mut memory_size = batches
            .iter()
            .map(|batch| batch.get_array_memory_size())
            .sum();
        loop {
            let have_additional_cpus = tasks.len() < task_limit;
            if have_additional_cpus {
                if reservation.try_grow(memory_size).is_ok() {
                    break;
                } else if tasks.is_empty() {
                    memory_size = 0;
                    break;
                }
            }
            if let Some(res) = tasks.join_next().await {
                let size = res??;
                reservation.shrink(size);
            }
        }
        match frag_id.first() {
            Some(ScalarValue::UInt64(Some(frag_id))) => {
                let frag_id = *frag_id;
                let fragment = dataset.get_fragment(frag_id as usize).ok_or_else(|| {
                    error!(
                        fragment_id = frag_id,
                        dataset_uri = %dataset.uri(),
                        manifest_version = dataset.manifest().version,
                        manifest_path = %dataset.manifest_location().path,
                        branch = ?dataset.manifest().branch,
                        "Non-existent fragment id returned from merge result",
                    );
                    Error::internal(format!(
                        "Got non-existent fragment id from merge result: {} (uri={}, version={}, manifest={}, branch={})",
                        frag_id,
                        dataset.uri(),
                        dataset.manifest().version,
                        dataset.manifest_location().path,
                        dataset.manifest().branch.as_deref().unwrap_or("main"),
                    ))
                })?;
                let metadata = fragment.metadata.clone();
                let fut = handle_fragment(
                    dataset.clone(),
                    fragment,
                    metadata,
                    batches,
                    updated_fragments.clone(),
                    memory_size,
                    current_version,
                );
                tasks.spawn(fut);
            }
            // A null fragment id marks rows with no target match: inserts.
            Some(ScalarValue::Null | ScalarValue::UInt64(None)) => {
                let fut = handle_new_fragments(
                    dataset.clone(),
                    batches,
                    new_fragments.clone(),
                    memory_size,
                );
                tasks.spawn(fut);
            }
            _ => {
                return Err(Error::internal(format!(
                    "Got non-fragment id from merge result: {:?}",
                    frag_id
                )));
            }
        };
    }
    // Drain the remaining rewrite tasks.
    while let Some(res) = tasks.join_next().await {
        let size = res??;
        reservation.shrink(size);
    }
    let mut updated_fragments = Arc::try_unwrap(updated_fragments)
        .unwrap()
        .into_inner()
        .unwrap();
    // The freshly written file is the last in each fragment; overwrite its
    // field ids in all older data files with -2 so readers use the new file.
    // NOTE(review): -2 appears to be a "field superseded" sentinel — confirm
    // against the fragment/data-file format documentation.
    let mut all_fields_updated = HashSet::new();
    for fragment in &mut updated_fragments {
        let updated_fields = fragment.files.last().unwrap().fields.clone();
        all_fields_updated.extend(updated_fields.iter().map(|&f| f as u32));
        for data_file in &mut fragment.files.iter_mut().rev().skip(1) {
            for field in &mut data_file.fields {
                if updated_fields.contains(field) {
                    *field = -2;
                }
            }
        }
    }
    let new_fragments = Arc::try_unwrap(new_fragments)
        .unwrap()
        .into_inner()
        .unwrap();
    Ok((
        updated_fragments,
        new_fragments,
        all_fields_updated.into_iter().collect(),
    ))
}
/// Executes the merge insert, committing the result and retrying on
/// conflicts according to the configured retry policy.
pub async fn execute(
    self,
    source: SendableRecordBatchStream,
) -> Result<(Arc<Dataset>, MergeStats)> {
    // When retries are possible the source must be replayable across attempts.
    let replayable = self.params.conflict_retries > 0;
    let source_iter = super::new_source_iter(source, replayable).await?;
    let dataset = self.dataset.clone();
    let retry_config = RetryConfig {
        max_retries: self.params.conflict_retries,
        retry_timeout: self.params.retry_timeout,
    };
    let retryable_job = MergeInsertJobWithIterator {
        job: self,
        source_iter: Arc::new(Mutex::new(source_iter)),
        attempt_count: Arc::new(AtomicU32::new(0)),
    };
    Box::pin(execute_with_retry(retryable_job, dataset, retry_config)).await
}
/// Runs the merge but returns the uncommitted transaction and statistics
/// instead of committing to the dataset.
pub async fn execute_uncommitted(
    self,
    source: impl StreamingWriteSource,
) -> Result<UncommittedMergeInsert> {
    let batch_stream = source.into_stream();
    self.execute_uncommitted_impl(batch_stream).await
}
/// Chooses the join type for the plan-based merge path.
///
/// The source side must be preserved when unmatched source rows will be
/// inserted; the target side must be preserved when unmatched target rows
/// may be deleted. Note the dataset scan is the LEFT input of the join
/// (see `create_plan`), hence Left preserves the target.
fn create_plan_join_type(&self) -> JoinType {
    let keep_source = self.params.insert_not_matched;
    let keep_target = !matches!(
        self.params.delete_not_matched_by_source,
        WhenNotMatchedBySource::Keep
    );
    match (keep_target, keep_source) {
        (true, true) => JoinType::Full,
        (true, false) => JoinType::Left,
        (false, true) => JoinType::Right,
        (false, false) => JoinType::Inner,
    }
}
/// Builds a physical plan performing the entire merge (join, per-row action
/// assignment, and write) via a custom DataFusion extension node.
///
/// The dataset scan is the LEFT join input and the source stream the RIGHT;
/// `create_plan_join_type` depends on this ordering.
async fn create_plan(
    self,
    source: SendableRecordBatchStream,
) -> Result<Arc<dyn ExecutionPlan>> {
    let session_config = SessionConfig::default();
    let session_ctx = SessionContext::new_with_config(session_config);
    // Unordered scan with row ids and row addresses exposed.
    let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?;
    // Quote key names so special characters survive identifier parsing.
    let on_cols = self
        .params
        .on
        .iter()
        .map(|name| format!("\"{}\"", name))
        .collect::<Vec<_>>();
    let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();
    let source_df = session_ctx.read_one_shot(source)?;
    let source_df_aliased = source_df.alias("source")?;
    let scan_aliased = scan.alias("target")?;
    let join_type = self.create_plan_join_type();
    let dataset_schema: Schema = self.dataset.schema().into();
    // Join target-to-source and attach the per-row merge action column
    // computed from the merge parameters.
    let df = scan_aliased
        .join(
            source_df_aliased,
            join_type,
            &on_cols_refs,
            &on_cols_refs,
            None,
        )?
        .with_column(
            MERGE_ACTION_COLUMN,
            merge_insert_action(&self.params, Some(&dataset_schema))?,
        )?;
    // Wrap the logical join in the custom write node, then optimize and
    // lower it to a physical plan via MergeInsertPlanner.
    let (session_state, logical_plan) = df.into_parts();
    let write_node = logical_plan::MergeInsertWriteNode::new(
        logical_plan,
        self.dataset.clone(),
        self.params.clone(),
    );
    let logical_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(write_node),
    });
    let logical_plan = session_state.optimize(&logical_plan)?;
    let planner =
        DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(MergeInsertPlanner {})]);
    let physical_plan = planner
        .create_physical_plan(&logical_plan, &session_state)
        .await?;
    Ok(physical_plan)
}
/// Executes the plan-based ("v2") merge path without committing.
///
/// Runs the single-partition plan to exhaustion (the write node should emit
/// no rows), then extracts the statistics, the uncommitted transaction, the
/// affected-row set, and (for full-schema merges) the inserted-key filter
/// from the terminal exec node.
async fn execute_uncommitted_v2(
    self,
    source: SendableRecordBatchStream,
) -> Result<(
    Transaction,
    MergeStats,
    Option<RowAddrTreeMap>,
    Option<KeyExistenceFilter>,
)> {
    let plan = self.create_plan(source).await?;
    // The write node must run in exactly one partition so its state
    // (stats/transaction) accumulates in a single place.
    let partition_count = match plan.properties().output_partitioning() {
        datafusion_physical_expr::Partitioning::RoundRobinBatch(n) => *n,
        datafusion_physical_expr::Partitioning::Hash(_, n) => *n,
        datafusion_physical_expr::Partitioning::UnknownPartitioning(n) => *n,
    };
    if partition_count != 1 {
        return Err(Error::invalid_input(format!(
            "Expected exactly 1 partition, got {}",
            partition_count
        )));
    }
    let task_context = Arc::new(datafusion::execution::TaskContext::default());
    let mut stream = plan.execute(0, task_context)?;
    // Drain the stream; any produced rows indicate a malformed plan.
    if let Some(batch) = stream.next().await {
        let batch = batch?;
        if batch.num_rows() > 0 {
            return Err(Error::invalid_input(format!(
                "Expected no output from write operation, got {} rows",
                batch.num_rows()
            )));
        }
    }
    // Pull results out of whichever merge exec variant the planner chose.
    let (stats, transaction, affected_rows, inserted_rows_filter) = if let Some(full_exec) =
        plan.as_any()
            .downcast_ref::<exec::FullSchemaMergeInsertExec>()
    {
        let stats = full_exec.merge_stats().ok_or_else(|| {
            Error::internal("Merge stats not available - execution may not have completed")
        })?;
        let transaction = full_exec.transaction().ok_or_else(|| {
            Error::internal("Transaction not available - execution may not have completed")
        })?;
        let affected_rows = full_exec.affected_rows().map(RowAddrTreeMap::from);
        let inserted_rows_filter = full_exec.inserted_rows_filter();
        (stats, transaction, affected_rows, inserted_rows_filter)
    } else if let Some(delete_exec) = plan
        .as_any()
        .downcast_ref::<exec::DeleteOnlyMergeInsertExec>()
    {
        let stats = delete_exec.merge_stats().ok_or_else(|| {
            Error::internal("Merge stats not available - execution may not have completed")
        })?;
        let transaction = delete_exec.transaction().ok_or_else(|| {
            Error::internal("Transaction not available - execution may not have completed")
        })?;
        let affected_rows = delete_exec.affected_rows().map(RowAddrTreeMap::from);
        // Delete-only merges never insert rows, hence no key filter.
        (stats, transaction, affected_rows, None)
    } else {
        return Err(Error::internal(
            "Expected FullSchemaMergeInsertExec or DeleteOnlyMergeInsertExec",
        ));
    };
    Ok((transaction, stats, affected_rows, inserted_rows_filter))
}
/// Decides whether the plan-based ("v2") merge path can handle this job.
///
/// Requirements: a supported `when_matched` behavior, no usable scalar
/// index (or index use disabled), a full-schema source (or a key-only
/// source for delete/no-op merges without inserts), and any
/// `when_not_matched_by_source` behavior.
async fn can_use_create_plan(&self, source_schema: &Schema) -> Result<bool> {
    let lance_schema = lance_core::datatypes::Schema::try_from(source_schema)?;
    let full_schema = self.dataset.schema();
    // Source covers every dataset column (order-insensitive).
    let is_full_schema = full_schema.compare_with_options(
        &lance_schema,
        &SchemaCompareOptions {
            compare_metadata: false,
            compare_nullability: NullabilityComparison::Ignore,
            ignore_field_order: true,
            ..Default::default()
        },
    );
    let has_scalar_index = self.join_key_as_scalar_index().await?.is_some();
    // Delete/no-op merges without inserts only need the join key columns.
    let no_upsert = matches!(
        self.params.when_matched,
        WhenMatched::Delete | WhenMatched::DoNothing
    ) && !self.params.insert_not_matched;
    let source_has_key_columns = self.params.on.iter().all(|key| {
        source_schema
            .fields()
            .iter()
            .any(|f| f.name() == key.as_str())
    });
    let schema_ok = is_full_schema || (no_upsert && source_has_key_columns);
    let matched_supported = matches!(
        self.params.when_matched,
        WhenMatched::UpdateAll
            | WhenMatched::UpdateIf(_)
            | WhenMatched::Fail
            | WhenMatched::Delete
    );
    let index_ok = !self.params.use_index || !has_scalar_index;
    let not_matched_by_source_supported = matches!(
        self.params.delete_not_matched_by_source,
        WhenNotMatchedBySource::Keep
            | WhenNotMatchedBySource::Delete
            | WhenNotMatchedBySource::DeleteIf(_)
    );
    Ok(matched_supported && index_ok && schema_ok && not_matched_by_source_supported)
}
/// Executes the merge-insert and returns the uncommitted transaction, merge
/// statistics, and (when available) the set of affected row addresses, so the
/// caller can commit (possibly with retries) separately.
async fn execute_uncommitted_impl(
    self,
    source: SendableRecordBatchStream,
) -> Result<UncommittedMergeInsert> {
    // Prefer the datafusion-plan-based fast path whenever the configuration allows it.
    let can_use_fast_path = self.can_use_create_plan(source.schema().as_ref()).await?;
    if can_use_fast_path {
        let (transaction, stats, affected_rows, inserted_rows_filter) =
            self.execute_uncommitted_v2(source).await?;
        return Ok(UncommittedMergeInsert {
            transaction,
            affected_rows,
            stats,
            inserted_rows_filter,
        });
    }
    // Slow path: join the source against the target and merge batch-by-batch.
    let source_schema = source.schema();
    let lance_schema = lance_core::datatypes::Schema::try_from(source_schema.as_ref())?;
    let full_schema = self.dataset.schema();
    // Does the source cover the full target schema? (metadata/nullability ignored)
    let is_full_schema = full_schema.compare_with_options(
        &lance_schema,
        &SchemaCompareOptions {
            compare_metadata: false,
            compare_nullability: NullabilityComparison::Ignore,
            ..Default::default()
        },
    );
    let joined = self.create_joined_stream(source).await?;
    let merger = Merger::try_new(
        self.params.clone(),
        source_schema,
        !is_full_schema,
        self.dataset.manifest.uses_stable_row_ids(),
    )?;
    // Shared state the merger populates as it processes batches; we read it
    // back only after the stream has been fully drained.
    let merge_statistics = merger.merge_stats.clone();
    let deleted_rows = merger.deleted_rows.clone();
    let updating_row_ids = merger.updating_row_ids.clone();
    let merger_schema = merger.output_schema().clone();
    let stream = joined
        .and_then(move |batch| merger.clone().execute_batch(batch))
        .try_flatten();
    let stream = RecordBatchStreamAdapter::new(merger_schema, stream);
    let (operation, affected_rows) = if !is_full_schema {
        // Partial schema: rewrite only the modified columns of existing fragments.
        if !matches!(
            self.params.delete_not_matched_by_source,
            WhenNotMatchedBySource::Keep
        ) {
            return Err(Error::not_supported_source("Deleting rows from the target table when there is no match in the source table is not supported when the source data has a different schema than the target data".into()));
        }
        let (updated_fragments, new_fragments, fields_modified) = Self::update_fragments(
            self.dataset.clone(),
            Box::pin(stream),
            self.dataset.manifest.version + 1,
        )
        .await?;
        let operation = Operation::Update {
            removed_fragment_ids: Vec::new(),
            updated_fragments,
            new_fragments,
            fields_modified,
            merged_generations: self.params.merged_generations.clone(),
            fields_for_preserving_frag_bitmap: vec![],
            update_mode: Some(RewriteColumns),
            inserted_rows_filter: None,
        };
        // No affected-rows set in this mode.
        (operation, None)
    } else {
        // Full schema: write merged rows as new fragments, then delete the rows
        // they replace from the old fragments.
        let (mut new_fragments, _) = write_fragments_internal(
            Some(&self.dataset),
            self.dataset.object_store.clone(),
            &self.dataset.base,
            self.dataset.schema().clone(),
            Box::pin(stream),
            WriteParams::default(),
            None,
        )
        .await?;
        // With stable row ids, re-attach the captured row id sequences to the
        // freshly written fragments so updated rows keep their original ids.
        if let Some(row_id_sequence) = updating_row_ids.lock().unwrap().row_id_sequence() {
            let fragment_sizes = new_fragments
                .iter()
                .map(|f| f.physical_rows.unwrap() as u64);
            let sequences = lance_table::rowids::rechunk_sequences(
                [row_id_sequence.clone()],
                fragment_sizes,
                true,
            )
            .map_err(|e| {
                Error::internal(format!(
                    "Captured row ids not equal to number of rows written: {}",
                    e
                ))
            })?;
            for (fragment, sequence) in new_fragments.iter_mut().zip(sequences) {
                let serialized = lance_table::rowids::write_row_ids(&sequence);
                fragment.row_id_meta = Some(RowIdMeta::Inline(serialized));
            }
        }
        // Reclaim sole ownership of the deleted-row list collected by the merger
        // (all merger clones are gone once the stream has been consumed).
        let removed_row_ids = Arc::into_inner(deleted_rows).unwrap().into_inner().unwrap();
        // If the dataset has a row-id index (stable row ids), translate the
        // deleted row ids into row addresses; otherwise ids already are addresses.
        let removed_row_addr_vec =
            if let Some(row_id_index) = get_row_id_index(&self.dataset).await? {
                let addresses: Vec<u64> = removed_row_ids
                    .iter()
                    .filter_map(|id| row_id_index.get(*id).map(|address| address.into()))
                    .collect::<Vec<_>>();
                addresses
            } else {
                removed_row_ids
            };
        let removed_row_addrs = RoaringTreemap::from_iter(removed_row_addr_vec.into_iter());
        let (old_fragments, removed_fragment_ids) =
            Self::apply_deletions(&self.dataset, &removed_row_addrs).await?;
        let operation = Operation::Update {
            removed_fragment_ids,
            updated_fragments: old_fragments,
            new_fragments,
            fields_modified: vec![],
            merged_generations: self.params.merged_generations.clone(),
            fields_for_preserving_frag_bitmap: full_schema
                .fields
                .iter()
                .map(|f| f.id as u32)
                .collect(),
            update_mode: Some(RewriteRows),
            inserted_rows_filter: None,
        };
        let affected_rows = Some(RowAddrTreeMap::from(removed_row_addrs));
        (operation, affected_rows)
    };
    // Same ownership argument as `deleted_rows` above.
    let stats = Arc::into_inner(merge_statistics)
        .unwrap()
        .into_inner()
        .unwrap();
    let transaction = Transaction::new(self.dataset.manifest.version, operation, None);
    Ok(UncommittedMergeInsert {
        transaction,
        affected_rows,
        stats,
        inserted_rows_filter: None,
    })
}
/// Applies the collected row deletions to the dataset's fragments.
///
/// Returns the updated fragment metadata plus the ids of fragments that
/// became fully deleted (and should therefore be removed entirely).
async fn apply_deletions(
    dataset: &Dataset,
    removed_row_ids: &RoaringTreemap,
) -> Result<(Vec<Fragment>, Vec<u64>)> {
    // Split the treemap into per-fragment bitmaps, keyed by fragment id
    // (the high 32 bits of the row address).
    let bitmaps = Arc::new(removed_row_ids.bitmaps().collect::<BTreeMap<_, _>>());
    // Outcome of processing a single fragment.
    enum FragmentChange {
        Unchanged,
        Modified(Box<Fragment>),
        Removed(u64),
    }
    let mut updated_fragments = Vec::new();
    let mut removed_fragments = Vec::new();
    // Extend deletion vectors concurrently, bounded by the store's IO parallelism.
    let mut stream = futures::stream::iter(dataset.get_fragments())
        .map(move |fragment| {
            let bitmaps_ref = bitmaps.clone();
            async move {
                let fragment_id = fragment.id();
                if let Some(bitmap) = bitmaps_ref.get(&(fragment_id as u32)) {
                    match fragment.extend_deletions(*bitmap).await {
                        Ok(Some(new_fragment)) => {
                            Ok(FragmentChange::Modified(Box::new(new_fragment.metadata)))
                        }
                        // `None` means every row in the fragment is now deleted.
                        Ok(None) => Ok(FragmentChange::Removed(fragment_id as u64)),
                        Err(e) => Err(e),
                    }
                } else {
                    Ok(FragmentChange::Unchanged)
                }
            }
        })
        .buffer_unordered(dataset.object_store.io_parallelism());
    while let Some(res) = stream.next().await.transpose()? {
        match res {
            FragmentChange::Unchanged => {}
            FragmentChange::Modified(fragment) => updated_fragments.push(*fragment),
            FragmentChange::Removed(fragment_id) => removed_fragments.push(fragment_id),
        }
    }
    Ok((updated_fragments, removed_fragments))
}
pub async fn explain_plan(&self, schema: Option<&Schema>, verbose: bool) -> Result<String> {
let schema = match schema {
Some(s) => s.clone(),
None => arrow_schema::Schema::from(self.dataset.schema()),
};
if !self.can_use_create_plan(&schema).await? {
return Err(Error::not_supported_source("This merge insert configuration does not support explain_plan. Only full-schema merge insert operations without a scalar-index execution path are currently supported.".into()));
}
let empty_batch = RecordBatch::new_empty(Arc::new(schema.clone()));
let stream = RecordBatchStreamAdapter::new(
Arc::new(schema.clone()),
futures::stream::once(async { Ok(empty_batch) }).boxed(),
);
let cloned_job = self.clone();
let plan = cloned_job.create_plan(Box::pin(stream)).await?;
let display = DisplayableExecutionPlan::new(plan.as_ref());
Ok(format!("{}", display.indent(verbose)))
}
pub async fn analyze_plan(&self, source: SendableRecordBatchStream) -> Result<String> {
if !self.can_use_create_plan(source.schema().as_ref()).await? {
return Err(Error::not_supported_source("This merge insert configuration does not support analyze_plan. Only full-schema merge insert operations without a scalar-index execution path are currently supported.".into()));
}
let cloned_job = self.clone();
let plan = cloned_job.create_plan(source).await?;
let options = LanceExecutionOptions::default();
let full_analysis = analyze_plan(plan, options).await?;
let lines: Vec<&str> = full_analysis.lines().collect();
let filtered_lines: Vec<&str> = lines
.into_iter()
.filter(|line| {
!line.trim_start().starts_with("AnalyzeExec")
&& !line.trim_start().starts_with("TracedExec")
})
.collect();
Ok(filtered_lines.join("\n"))
}
}
/// Summary counters describing the outcome of a merge-insert operation.
#[derive(Debug, Default, Clone)]
pub struct MergeStats {
    /// Rows inserted because they had no match in the target.
    pub num_inserted_rows: u64,
    /// Rows updated because they matched an existing target row.
    pub num_updated_rows: u64,
    /// Rows deleted from the target.
    pub num_deleted_rows: u64,
    /// Number of execution attempts (incremented per retry on commit conflicts).
    pub num_attempts: u32,
    // NOTE(review): bytes_written / num_files_written are not populated in this
    // chunk — presumably filled in by the fast-path execution; confirm.
    pub bytes_written: u64,
    pub num_files_written: u64,
    /// Duplicate source rows skipped (see `SourceDedupeBehavior::FirstSeen`).
    pub num_skipped_duplicates: u64,
}
/// The result of a merge-insert that has been executed but not yet committed.
pub struct UncommittedMergeInsert {
    /// The uncommitted transaction describing the merge.
    pub transaction: Transaction,
    /// Row addresses touched by the merge, when known (used for conflict checks).
    pub affected_rows: Option<RowAddrTreeMap>,
    /// Statistics gathered while executing the merge.
    pub stats: MergeStats,
    /// Filter over inserted keys, when produced by the fast path.
    pub inserted_rows_filter: Option<KeyExistenceFilter>,
}
/// A merge-insert job paired with an iterator of source streams, allowing the
/// retry machinery to pull a fresh stream for each attempt.
#[derive(Clone)]
struct MergeInsertJobWithIterator {
    job: MergeInsertJob,
    // One source stream per attempt; locked because the executor is cloneable.
    source_iter: Arc<Mutex<Box<dyn Iterator<Item = SendableRecordBatchStream> + Send + 'static>>>,
    // Counts attempts across retries; reported in MergeStats::num_attempts.
    attempt_count: Arc<AtomicU32>,
}
impl RetryExecutor for MergeInsertJobWithIterator {
    type Data = UncommittedMergeInsert;
    type Result = (Arc<Dataset>, MergeStats);
    /// Runs one merge attempt using the next source stream from the iterator.
    async fn execute_impl(&self) -> Result<Self::Data> {
        self.attempt_count.fetch_add(1, Ordering::SeqCst);
        // Each attempt consumes one stream; the iterator must supply one per retry.
        let stream = self.source_iter.lock().unwrap().next().unwrap();
        self.job.clone().execute_uncommitted_impl(stream).await
    }
    /// Commits the uncommitted merge result against the given dataset version.
    async fn commit(&self, dataset: Arc<Dataset>, mut data: Self::Data) -> Result<Self::Result> {
        // Record how many attempts it took before this successful commit.
        data.stats.num_attempts = self.attempt_count.load(Ordering::SeqCst);
        let mut commit_builder =
            CommitBuilder::new(dataset).with_skip_auto_cleanup(self.job.params.skip_auto_cleanup);
        if let Some(commit_retries) = self.job.params.commit_retries {
            commit_builder = commit_builder.with_max_retries(commit_retries);
        }
        if let Some(affected_rows) = data.affected_rows {
            commit_builder = commit_builder.with_affected_rows(affected_rows);
        }
        let new_dataset = commit_builder.execute(data.transaction).await?;
        Ok((Arc::new(new_dataset), data.stats))
    }
    /// Points the job at the refreshed dataset before the next retry.
    fn update_dataset(&mut self, dataset: Arc<Dataset>) {
        self.job.dataset = dataset;
    }
}
/// Merges joined (source ⋈ target) batches into output batches according to
/// the merge-insert parameters, accumulating deletions and statistics in
/// shared state so clones of the merger can process batches independently.
#[derive(Debug, Clone)]
struct Merger {
    // Row ids (or addresses) of target rows to delete, shared across clones.
    deleted_rows: Arc<Mutex<Vec<u64>>>,
    // Row ids of updated rows, captured when stable row ids are enabled.
    updating_row_ids: Arc<Mutex<CapturedRowIds>>,
    // Compiled `WhenNotMatchedBySource::DeleteIf` predicate, if configured.
    delete_expr: Option<Arc<dyn PhysicalExpr>>,
    // Shared statistics accumulator.
    merge_stats: Arc<Mutex<MergeStats>>,
    // Compiled `WhenMatched::UpdateIf` predicate, if configured.
    match_filter_expr: Option<Arc<dyn PhysicalExpr>>,
    params: MergeInsertParams,
    // Source-side schema (the left half of the joined batch).
    schema: Arc<Schema>,
    // Whether output batches carry a row-address column (partial-schema mode).
    with_row_addr: bool,
    output_schema: Arc<Schema>,
    enable_stable_row_ids: bool,
    // Row ids already emitted, used to detect duplicate source matches.
    processed_row_ids: Arc<Mutex<HashSet<u64>>>,
}
impl Merger {
    /// Creates a merger, pre-compiling the optional `DeleteIf` and `UpdateIf`
    /// filter expressions and computing the output schema.
    fn try_new(
        params: MergeInsertParams,
        schema: Arc<Schema>,
        with_row_addr: bool,
        enable_stable_row_ids: bool,
    ) -> Result<Self> {
        // Compile the `WhenNotMatchedBySource::DeleteIf` predicate against the
        // source-side schema; it must evaluate to a boolean.
        let delete_expr = if let WhenNotMatchedBySource::DeleteIf(expr) =
            &params.delete_not_matched_by_source
        {
            let planner = Planner::new(schema.clone());
            let expr = planner.optimize_expr(expr.clone())?;
            let physical_expr = planner.create_physical_expr(&expr)?;
            let data_type = physical_expr.data_type(&schema)?;
            if data_type != DataType::Boolean {
                return Err(Error::invalid_input(format!(
                    "Merge insert conditions must be expressions that return a boolean value, received expression ({}) which has data type {}",
                    expr, data_type
                )));
            }
            Some(physical_expr)
        } else {
            None
        };
        // Compile the `WhenMatched::UpdateIf` predicate against the combined
        // (source + target) schema so it can reference both sides.
        let match_filter_expr = if let WhenMatched::UpdateIf(expr_str) = &params.when_matched {
            let combined_schema = Arc::new(combined_schema(&schema));
            let planner = Planner::new(combined_schema.clone());
            let expr = planner.parse_filter(expr_str)?;
            let expr = planner.optimize_expr(expr)?;
            let match_expr = planner.create_physical_expr(&expr)?;
            let data_type = match_expr.data_type(combined_schema.as_ref())?;
            if data_type != DataType::Boolean {
                return Err(Error::invalid_input(format!(
                    "Merge insert conditions must be expressions that return a boolean value, received a 'when matched update if' expression ({}) which has data type {}",
                    expr, data_type
                )));
            }
            Some(match_expr)
        } else {
            None
        };
        // Partial-schema updates also carry the target row address column.
        let output_schema = if with_row_addr {
            Arc::new(schema.try_with_column(ROW_ADDR_FIELD.clone())?)
        } else {
            schema.clone()
        };
        Ok(Self {
            deleted_rows: Arc::new(Mutex::new(Vec::new())),
            updating_row_ids: Arc::new(Mutex::new(CapturedRowIds::new(enable_stable_row_ids))),
            delete_expr,
            merge_stats: Arc::new(Mutex::new(MergeStats::default())),
            match_filter_expr,
            params,
            schema,
            with_row_addr,
            output_schema,
            enable_stable_row_ids,
            processed_row_ids: Arc::new(Mutex::new(HashSet::new())),
        })
    }
    /// Schema of the batches this merger emits.
    fn output_schema(&self) -> &Arc<Schema> {
        &self.output_schema
    }
    /// Returns, per row, whether at least one of the `num_cols` columns
    /// starting at `col_offset` is non-null. Used to detect which side(s) of
    /// the outer join produced each row.
    fn not_all_null(
        batch: &RecordBatch,
        col_offset: usize,
        num_cols: usize,
    ) -> Result<BooleanArray> {
        debug_assert_ne!(num_cols, 0);
        let mut at_least_one_valid = arrow::compute::is_not_null(batch.column(col_offset))?;
        for idx in col_offset + 1..col_offset + num_cols {
            let is_valid = arrow::compute::is_not_null(batch.column(idx))?;
            at_least_one_valid = arrow::compute::or(&at_least_one_valid, &is_valid)?;
        }
        Ok(at_least_one_valid)
    }
    /// Splits a joined batch into (source-only, matched, target-only) row masks
    /// by inspecting the nullness of the join key columns on each side.
    fn extract_selections(
        &self,
        combined_batch: &RecordBatch,
        right_offset: usize,
        num_keys: usize,
    ) -> Result<(BooleanArray, BooleanArray, BooleanArray)> {
        let in_left = Self::not_all_null(combined_batch, 0, num_keys)?;
        let in_right = Self::not_all_null(combined_batch, right_offset, num_keys)?;
        let in_both = arrow::compute::and(&in_left, &in_right)?;
        let left_only = arrow::compute::and(&in_left, &arrow::compute::not(&in_right)?)?;
        let right_only = arrow::compute::and(&arrow::compute::not(&in_left)?, &in_right)?;
        Ok((left_only, in_both, right_only))
    }
    /// Processes one joined batch, emitting output batches for updates and
    /// inserts while recording deletions and statistics in shared state.
    async fn execute_batch(
        self,
        batch: RecordBatch,
    ) -> datafusion::common::Result<impl Stream<Item = datafusion::common::Result<RecordBatch>>>
    {
        // Held for the whole batch; all counter updates below are under this lock.
        let mut merge_statistics = self.merge_stats.lock().unwrap();
        let num_fields = batch.schema().fields.len();
        // Layout: [source cols][target cols][_rowid] or
        //         [source cols][target cols][_rowid][_rowaddr].
        // An odd field count implies no row-address column.
        let (row_id_col, row_addr_col, right_offset) = if num_fields % 2 == 1 {
            assert!(!self.with_row_addr);
            (num_fields - 1, None, num_fields / 2)
        } else {
            assert!(self.with_row_addr);
            (num_fields - 2, Some(num_fields - 1), (num_fields - 2) / 2)
        };
        let num_keys = self.params.on.len();
        let left_cols = Vec::from_iter(0..right_offset);
        let right_cols_with_id = Vec::from_iter(right_offset..num_fields);
        let mut batches = Vec::with_capacity(2);
        let (left_only, in_both, right_only) =
            self.extract_selections(&batch, right_offset, num_keys)?;
        let mut deleted_row_ids = self.deleted_rows.lock().unwrap();
        // --- Matched rows: update (or delete) existing target rows ---
        if self.params.when_matched != WhenMatched::DoNothing {
            let mut matched = arrow::compute::filter_record_batch(&batch, &in_both)?;
            // Apply the optional `update if` predicate over source+target columns.
            if let Some(match_filter) = self.match_filter_expr {
                let unzipped = unzip_batch(&matched, &self.schema);
                let filtered = match_filter.evaluate(&unzipped)?;
                match filtered {
                    ColumnarValue::Array(mask) => {
                        matched = arrow::compute::filter_record_batch(&matched, mask.as_boolean())?;
                    }
                    ColumnarValue::Scalar(scalar) => {
                        // Scalar true keeps everything; anything else keeps nothing.
                        if let ScalarValue::Boolean(Some(true)) = scalar {
                        } else {
                            matched = RecordBatch::new_empty(matched.schema());
                        }
                    }
                }
            }
            merge_statistics.num_updated_rows += matched.num_rows() as u64;
            if matched.num_rows() > 0 {
                // Deduplicate target rows matched by multiple source rows.
                let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
                let mut processed_row_ids = self.processed_row_ids.lock().unwrap();
                let mut keep_indices: Vec<u32> = Vec::with_capacity(matched.num_rows());
                for (row_idx, &row_id) in row_ids.values().iter().enumerate() {
                    if processed_row_ids.insert(row_id) {
                        keep_indices.push(row_idx as u32);
                    } else {
                        match self.params.source_dedupe_behavior {
                            SourceDedupeBehavior::Fail => {
                                return Err(create_duplicate_row_error(
                                    &matched,
                                    row_idx,
                                    &self.params.on,
                                ));
                            }
                            SourceDedupeBehavior::FirstSeen => {
                                // Silently drop the duplicate; counted below.
                            }
                        }
                    }
                }
                drop(processed_row_ids);
                let num_skipped = matched.num_rows() - keep_indices.len();
                if num_skipped > 0 {
                    merge_statistics.num_skipped_duplicates += num_skipped as u64;
                    // Undo the optimistic count added before deduplication.
                    merge_statistics.num_updated_rows -= num_skipped as u64;
                    let indices = UInt32Array::from(keep_indices);
                    matched = take_record_batch(&matched, &indices)?;
                }
                if matched.num_rows() > 0 {
                    // Updates are implemented as delete-old-row + write-new-row.
                    let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
                    deleted_row_ids.extend(row_ids.values());
                    if self.enable_stable_row_ids {
                        self.updating_row_ids
                            .lock()
                            .unwrap()
                            .capture(row_ids.values())?;
                    }
                    // Emit the source columns (plus row address in partial mode).
                    let projection = if let Some(row_addr_col) = row_addr_col {
                        let mut cols = Vec::from_iter(left_cols.iter().cloned());
                        cols.push(row_addr_col);
                        cols
                    } else {
                        #[allow(clippy::redundant_clone)]
                        left_cols.clone()
                    };
                    let matched = matched.project(&projection)?;
                    let matched = RecordBatch::try_new(
                        self.output_schema.clone(),
                        Vec::from_iter(matched.columns().iter().cloned()),
                    )?;
                    batches.push(Ok(matched));
                }
            }
        }
        // --- Source-only rows: inserts ---
        if self.params.insert_not_matched {
            let not_matched = arrow::compute::filter_record_batch(&batch, &left_only)?;
            let left_cols_with_id = left_cols
                .into_iter()
                .chain(row_addr_col)
                .collect::<Vec<_>>();
            let not_matched = not_matched.project(&left_cols_with_id)?;
            let not_matched = RecordBatch::try_new(
                self.output_schema.clone(),
                Vec::from_iter(not_matched.columns().iter().cloned()),
            )?;
            merge_statistics.num_inserted_rows += not_matched.num_rows() as u64;
            batches.push(Ok(not_matched));
        }
        // --- Target-only rows: optional deletes ---
        match self.params.delete_not_matched_by_source {
            WhenNotMatchedBySource::Delete => {
                let unmatched = arrow::compute::filter(batch.column(row_id_col), &right_only)?;
                merge_statistics.num_deleted_rows += unmatched.len() as u64;
                let row_ids = unmatched.as_primitive::<UInt64Type>();
                deleted_row_ids.extend(row_ids.values());
            }
            WhenNotMatchedBySource::DeleteIf(_) => {
                // Evaluate the predicate over target columns only.
                let target_data = batch.project(&right_cols_with_id)?;
                let unmatched = arrow::compute::filter_record_batch(&target_data, &right_only)?;
                // Row id is the last column of the projected target data.
                let row_id_col = unmatched.num_columns() - 1;
                let to_delete = self.delete_expr.unwrap().evaluate(&unmatched)?;
                match to_delete {
                    ColumnarValue::Array(mask) => {
                        let row_ids = arrow::compute::filter(
                            unmatched.column(row_id_col),
                            mask.as_boolean(),
                        )?;
                        let row_ids = row_ids.as_primitive::<UInt64Type>();
                        merge_statistics.num_deleted_rows += row_ids.len() as u64;
                        deleted_row_ids.extend(row_ids.values());
                    }
                    ColumnarValue::Scalar(scalar) => {
                        // Scalar true deletes every unmatched target row.
                        if let ScalarValue::Boolean(Some(true)) = scalar {
                            let row_ids = unmatched.column(row_id_col).as_primitive::<UInt64Type>();
                            merge_statistics.num_deleted_rows += row_ids.len() as u64;
                            deleted_row_ids.extend(row_ids.values());
                        }
                    }
                }
            }
            WhenNotMatchedBySource::Keep => {}
        }
        Ok(stream::iter(batches))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::scanner::ColumnOrdering;
use crate::dataset::write::merge_insert::inserted_rows::{
KeyExistenceFilter, KeyExistenceFilterBuilder, extract_key_value_from_batch,
};
use crate::index::vector::VectorIndexParams;
use crate::io::commit::read_transaction_file;
use crate::{
dataset::{InsertBuilder, ReadParams, WriteMode, WriteParams, builder::DatasetBuilder},
session::Session,
utils::test::{
DatagenExt, FragmentCount, FragmentRowCount, ThrottledStoreWrapper,
assert_plan_node_equals, assert_string_matches,
},
};
use arrow_array::RecordBatch;
use arrow_array::builder::{ListBuilder, StringBuilder};
use arrow_array::types::Float32Type;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Float64Array, Int32Array, Int64Array, ListArray,
RecordBatchIterator, RecordBatchReader, StringArray, StructArray, UInt32Array,
types::{Int32Type, UInt32Type},
};
use arrow_buffer::{OffsetBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::concat::concat_batches;
use datafusion::common::Column;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join_all};
use lance_arrow::FixedSizeListArrayExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_datafusion::{datagen::DatafusionDatagenExt, utils::reader_to_stream};
use lance_datagen::{BatchCount, Dimension, RowCount, Seed, array};
use lance_index::IndexType;
use lance_index::scalar::ScalarIndexParams;
use lance_io::object_store::ObjectStoreParams;
use lance_linalg::distance::MetricType;
use mock_instant::thread_local::MockClock;
use object_store::throttle::ThrottleConfig;
use roaring::RoaringBitmap;
use std::collections::HashMap;
use tokio::sync::{Barrier, Notify};
/// Compile-time check that a value's type is `Send`; passes the value through.
fn assert_send<T: Send>(value: T) -> T {
    value
}
/// Restores the job's dataset to its original version, runs the merge-insert
/// with `new_data`, then verifies the result: rows whose `value` column is 1
/// came from the original (left) data and rows with value 2 from the new
/// (right) data. Also checks the reported insert/update/delete statistics
/// against `stats` ([inserted, updated, deleted]). Returns the merged dataset.
async fn check_then_refresh_dataset(
    new_data: RecordBatch,
    mut job: MergeInsertJob,
    keys_from_left: &[u32],
    keys_from_right: &[u32],
    stats: &[u64],
) -> Arc<Dataset> {
    // Reset to the checked-out version so each invocation starts from the same state.
    let mut dataset = (*job.dataset).clone();
    dataset.restore().await.unwrap();
    job.dataset = Arc::new(dataset);
    let schema = new_data.schema();
    let new_reader = Box::new(RecordBatchIterator::new([Ok(new_data)], schema.clone()));
    let new_stream = reader_to_stream(new_reader);
    let (merged_dataset, merge_stats) = job.execute(new_stream).boxed().await.unwrap();
    // Read back the merged table in full.
    let batches = merged_dataset
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&schema, &batches).unwrap();
    // Pair each key (column 0) with its value (column 1).
    let keyvals = merged
        .column(0)
        .as_primitive::<UInt32Type>()
        .values()
        .iter()
        .zip(
            merged
                .column(1)
                .as_primitive::<UInt32Type>()
                .values()
                .iter(),
        );
    // value == 1 marks rows that survived from the original data.
    let mut left_keys = keyvals
        .clone()
        .filter(|&(_, &val)| val == 1)
        .map(|(key, _)| key)
        .copied()
        .collect::<Vec<_>>();
    // value == 2 marks rows inserted/updated from the new data.
    let mut right_keys = keyvals
        .clone()
        .filter(|&(_, &val)| val == 2)
        .map(|(key, _)| key)
        .copied()
        .collect::<Vec<_>>();
    // Scan order is not guaranteed; sort before comparing.
    left_keys.sort();
    right_keys.sort();
    assert_eq!(left_keys, keys_from_left);
    assert_eq!(right_keys, keys_from_right);
    assert_eq!(merge_stats.num_inserted_rows, stats[0]);
    assert_eq!(merge_stats.num_updated_rows, stats[1]);
    assert_eq!(merge_stats.num_deleted_rows, stats[2]);
    merged_dataset
}
/// Builds the three-column test schema: `key`, `value` (both u32) and the
/// string column `filterme`, all nullable.
fn create_test_schema() -> Arc<Schema> {
    let fields = vec![
        Field::new("key", DataType::UInt32, true),
        Field::new("value", DataType::UInt32, true),
        Field::new("filterme", DataType::Utf8, true),
    ];
    Arc::new(Schema::new(fields))
}
/// Builds the "source" batch used by the merge tests: keys 4..=9, every
/// value set to 2, and a cycling filter column.
fn create_new_batch(schema: Arc<Schema>) -> RecordBatch {
    let keys = UInt32Array::from(vec![4, 5, 6, 7, 8, 9]);
    let values = UInt32Array::from(vec![2; 6]);
    let filters = StringArray::from(vec!["A", "B", "C", "A", "B", "C"]);
    RecordBatch::try_new(
        schema,
        vec![Arc::new(keys), Arc::new(values), Arc::new(filters)],
    )
    .unwrap()
}
/// Writes a six-row test dataset (keys 1..=6, value always 1, cycling
/// `filterme` strings) split across two fragments of three rows each.
async fn create_test_dataset(
    test_uri: &str,
    version: LanceFileVersion,
    enable_stable_row_ids: bool,
) -> Arc<Dataset> {
    let dataset = lance_datagen::gen_batch()
        .col("key", array::step_custom::<UInt32Type>(1, 1))
        .col("value", array::fill::<UInt32Type>(1u32))
        .col(
            "filterme",
            array::cycle_utf8_literals(&["A", "B", "A", "A", "B", "A"]),
        )
        .into_dataset_with_params(
            test_uri,
            FragmentCount(2),
            FragmentRowCount(3),
            Some(WriteParams {
                // Keep fragments small so the tests exercise multi-fragment paths.
                max_rows_per_file: 3,
                data_storage_version: Some(version),
                enable_stable_row_ids,
                ..Default::default()
            }),
        )
        .await
        .unwrap();
    assert_eq!(2, dataset.get_fragments().len());
    Arc::new(dataset)
}
/// Returns the `_rowid` values for the rows whose `key` is in `keys`,
/// ordered by ascending key so results are comparable across versions.
async fn get_row_ids_for_keys(dataset: &Dataset, keys: &[u32]) -> UInt64Array {
    // e.g. "key IN (4,5,6)"
    let filter = format!(
        "key IN ({})",
        keys.iter()
            .map(|k| k.to_string())
            .collect::<Vec<_>>()
            .join(",")
    );
    let batch = dataset
        .scan()
        .filter(&filter)
        .unwrap()
        .with_row_id()
        // Deterministic order is required for equality checks on the row ids.
        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
            "key".to_string(),
        )]))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    batch
        .column_by_name(ROW_ID)
        .unwrap()
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap()
        .clone()
}
/// Builds the `key > 1` predicate used by the `DeleteIf` test cases.
fn create_delete_condition() -> Expr {
    let key_col = Expr::Column(Column::new_unqualified("key"));
    let threshold = Expr::Literal(ScalarValue::UInt32(Some(1)), None);
    key_col.gt(threshold)
}
/// Declarative harness for merge-insert tests: configures the dataset
/// (file version, stable row ids), the job to run, and the expected
/// post-merge keys and statistics.
struct MergeInsertTestBuilder {
    version: LanceFileVersion,
    enable_stable_row_ids: bool,
    // Keys whose row ids are compared before/after the merge.
    test_keys: Vec<u32>,
    // Keys expected to come from the original data (value == 1).
    expected_left_keys: Vec<u32>,
    // Keys expected to come from the new data (value == 2).
    expected_right_keys: Vec<u32>,
    // Expected [inserted, updated, deleted] counts.
    expected_stats: Vec<u64>,
    // Builds the MergeInsertJob from the freshly created dataset.
    job_builder: Option<Box<dyn FnOnce(Arc<Dataset>) -> MergeInsertJob>>,
}
impl MergeInsertTestBuilder {
    /// Starts a builder with empty expectations and default dataset settings.
    fn new() -> Self {
        Self {
            version: LanceFileVersion::default(),
            enable_stable_row_ids: false,
            test_keys: vec![],
            expected_left_keys: vec![],
            expected_right_keys: vec![],
            expected_stats: vec![],
            job_builder: None,
        }
    }
    /// Sets the data storage version for the test dataset.
    fn with_version(mut self, version: LanceFileVersion) -> Self {
        self.version = version;
        self
    }
    /// Enables or disables stable row ids for the test dataset.
    fn with_stable_row_ids(mut self, enable: bool) -> Self {
        self.enable_stable_row_ids = enable;
        self
    }
    /// Keys whose row ids are compared before and after the merge.
    fn with_test_keys(mut self, keys: &[u32]) -> Self {
        self.test_keys = keys.to_vec();
        self
    }
    /// Keys expected to survive from the original data.
    fn with_expected_left_keys(mut self, keys: &[u32]) -> Self {
        self.expected_left_keys = keys.to_vec();
        self
    }
    /// Keys expected to come from the merged-in data.
    fn with_expected_right_keys(mut self, keys: &[u32]) -> Self {
        self.expected_right_keys = keys.to_vec();
        self
    }
    /// Expected [inserted, updated, deleted] statistics.
    fn with_expected_stats(mut self, stats: &[u64]) -> Self {
        self.expected_stats = stats.to_vec();
        self
    }
    /// Supplies the closure that builds the merge-insert job under test.
    fn with_job_builder<F>(mut self, builder: F) -> Self
    where
        F: FnOnce(Arc<Dataset>) -> MergeInsertJob + 'static,
    {
        self.job_builder = Some(Box::new(builder));
        self
    }
    /// Runs the configured scenario and asserts the expectations, including
    /// row-id stability when stable row ids are enabled.
    async fn run_test(self) {
        let schema = create_test_schema();
        let new_batch = create_new_batch(schema.clone());
        let test_uri = "memory://test.lance";
        let ds = create_test_dataset(test_uri, self.version, self.enable_stable_row_ids).await;
        let row_ids_before = get_row_ids_for_keys(&ds, &self.test_keys).await;
        let job_builder = self.job_builder.expect("job_builder must be set");
        let job = job_builder(ds);
        let ds = check_then_refresh_dataset(
            new_batch,
            job,
            &self.expected_left_keys,
            &self.expected_right_keys,
            &self.expected_stats,
        )
        .await;
        let row_ids_after = get_row_ids_for_keys(&ds, &self.test_keys).await;
        // Stable row ids must survive updates; without them, updates rewrite rows
        // into new fragments and therefore new addresses.
        if self.enable_stable_row_ids {
            assert_eq!(row_ids_before, row_ids_after);
        } else {
            assert_ne!(row_ids_before, row_ids_after);
        }
    }
}
/// Building a merge-insert with no join keys (and no primary key on the
/// dataset) must fail with an InvalidInput error mentioning both options.
#[tokio::test]
async fn test_merge_insert_requires_on_or_primary_key() {
    let test_uri = "memory://merge_insert_requires_keys";
    let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await;
    let err = MergeInsertBuilder::try_new(ds, Vec::new()).unwrap_err();
    if let crate::Error::InvalidInput { source, .. } = err {
        let msg = source.to_string();
        assert!(
            msg.contains("requires join keys") && msg.contains("primary key"),
            "unexpected error message: {}",
            msg
        );
    } else {
        panic!("expected InvalidInput error");
    }
}
/// When no join keys are supplied but the schema declares an unenforced
/// primary key (via field metadata), the merge should default to joining on
/// that key.
#[tokio::test]
async fn test_merge_insert_defaults_to_unenforced_primary_key() {
    // Mark `id` as the unenforced primary key through schema metadata.
    let id_field = Field::new("id", DataType::Int32, false).with_metadata(
        [(
            "lance-schema:unenforced-primary-key".to_string(),
            "true".to_string(),
        )]
        .into(),
    );
    let value_field = Field::new("value", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![id_field, value_field]));
    let initial_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(Int32Array::from(vec![10, 20, 30])),
        ],
    )
    .unwrap();
    let reader = RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone());
    let dataset = Dataset::write(
        reader,
        "memory://merge_insert_pk_default",
        Some(WriteParams {
            data_storage_version: Some(LanceFileVersion::V2_0),
            ..Default::default()
        }),
    )
    .await
    .unwrap();
    let dataset = Arc::new(dataset);
    // Ids 2 and 3 overlap the existing data; id 4 is new.
    let new_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![2, 3, 4])),
            Arc::new(Int32Array::from(vec![200, 300, 400])),
        ],
    )
    .unwrap();
    // No explicit `on` keys: the primary key should be picked up automatically.
    let mut builder = MergeInsertBuilder::try_new(dataset.clone(), Vec::new()).unwrap();
    builder
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll);
    let job = builder.try_build().unwrap();
    let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
    let new_stream = reader_to_stream(new_reader);
    let (updated_dataset, stats) = job.execute(new_stream).await.unwrap();
    assert_eq!(stats.num_inserted_rows, 1);
    assert_eq!(stats.num_updated_rows, 2);
    assert_eq!(stats.num_deleted_rows, 0);
    // Verify the final table contents (order-independent).
    let result_batch = updated_dataset.scan().try_into_batch().await.unwrap();
    let ids = result_batch
        .column_by_name("id")
        .unwrap()
        .as_primitive::<Int32Type>();
    let values = result_batch
        .column_by_name("value")
        .unwrap()
        .as_primitive::<Int32Type>();
    let mut pairs = (0..ids.len())
        .map(|i| (ids.value(i), values.value(i)))
        .collect::<Vec<_>>();
    pairs.sort_unstable();
    assert_eq!(pairs, vec![(1, 10), (2, 200), (3, 300), (4, 400)]);
}
/// Exercises every combination of when-matched / when-not-matched /
/// when-not-matched-by-source against the standard six-row dataset
/// (keys 1..=6 with value 1) merged with keys 4..=9 (value 2).
#[rstest::rstest]
#[tokio::test]
async fn test_basic_merge(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
) {
    let schema = create_test_schema();
    let new_batch = create_new_batch(schema.clone());
    let test_uri = "memory://test.lance";
    let ds = create_test_dataset(test_uri, version, false).await;
    // Empty join keys are rejected.
    assert!(MergeInsertBuilder::try_new(ds.clone(), vec![]).is_err());
    let keys = vec!["key".to_string()];
    // Default: insert-only ("insert if not exists").
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1, 2, 3, 4, 5, 6],
        &[7, 8, 9],
        &[3, 0, 0],
    )
    .await;
    // Upsert: update matches, insert the rest.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1, 2, 3],
        &[4, 5, 6, 7, 8, 9],
        &[3, 3, 0],
    )
    .await;
    // Conditional update: only rows whose filterme differs are updated.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(
            WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
        )
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1, 2, 3, 4, 5],
        &[6, 7, 8, 9],
        &[3, 1, 0],
    )
    .await;
    // Condition that never matches: nothing changes.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_not_matched(WhenNotMatched::DoNothing)
        .when_matched(WhenMatched::update_if(&ds, "target.filterme = 'z'").unwrap())
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3, 4, 5, 6], &[], &[0, 0, 0])
        .await;
    // Update-only (no inserts).
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3], &[4, 5, 6], &[0, 3, 0])
        .await;
    // Conditional update-only.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(
            WhenMatched::update_if(&ds, "source.filterme == target.filterme").unwrap(),
        )
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3, 6], &[4, 5], &[0, 2, 0])
        .await;
    // A no-op configuration (nothing to insert, update, or delete) is rejected.
    assert!(
        MergeInsertBuilder::try_new(ds.clone(), keys.clone())
            .unwrap()
            .when_not_matched(WhenNotMatched::DoNothing)
            .try_build()
            .is_err()
    );
    // Insert new rows and delete target rows missing from the source.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[4, 5, 6], &[7, 8, 9], &[3, 0, 3])
        .await;
    // Full sync: upsert plus delete-not-matched-by-source.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[], &[4, 5, 6, 7, 8, 9], &[3, 3, 3])
        .await;
    // Conditional update plus delete-not-matched-by-source.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(
            WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
        )
        .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[4, 5], &[6, 7, 8, 9], &[3, 1, 3])
        .await;
    // Update + delete, no inserts.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::DoNothing)
        .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[], &[4, 5, 6], &[0, 3, 3]).await;
    // Delete-only.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_not_matched(WhenNotMatched::DoNothing)
        .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[4, 5, 6], &[], &[0, 0, 3]).await;
    // Conditional delete (`key > 1`) variants of the cases above.
    let condition = create_delete_condition();
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1, 4, 5, 6],
        &[7, 8, 9],
        &[3, 0, 2],
    )
    .await;
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1],
        &[4, 5, 6, 7, 8, 9],
        &[3, 3, 2],
    )
    .await;
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(
            WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
        )
        .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
        .try_build()
        .unwrap();
    check_then_refresh_dataset(
        new_batch.clone(),
        job,
        &[1, 4, 5],
        &[6, 7, 8, 9],
        &[3, 1, 2],
    )
    .await;
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::DoNothing)
        .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[1], &[4, 5, 6], &[0, 3, 2]).await;
    let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_not_matched(WhenNotMatched::DoNothing)
        .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
        .try_build()
        .unwrap();
    check_then_refresh_dataset(new_batch.clone(), job, &[1, 4, 5, 6], &[], &[0, 0, 2]).await;
}
/// Full-sync merge (upsert + delete-not-matched-by-source) with and without
/// stable row ids; the harness asserts row-id stability accordingly.
#[rstest::rstest]
#[tokio::test]
async fn test_upsert_and_delete_all_with_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    MergeInsertTestBuilder::new()
        .with_version(version)
        .with_stable_row_ids(enable_stable_row_ids)
        .with_test_keys(&[4, 5, 6])
        .with_expected_left_keys(&[])
        .with_expected_right_keys(&[4, 5, 6, 7, 8, 9])
        .with_expected_stats(&[3, 3, 3])
        .with_job_builder(|ds| {
            MergeInsertBuilder::try_new(ds, vec!["key".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
                .try_build()
                .unwrap()
        })
        .run_test()
        .await;
}
/// Plain upsert with and without stable row ids; updated rows must keep
/// their row ids only when stable row ids are enabled.
#[rstest::rstest]
#[tokio::test]
async fn test_upsert_only_with_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    MergeInsertTestBuilder::new()
        .with_version(version)
        .with_stable_row_ids(enable_stable_row_ids)
        .with_test_keys(&[4, 5, 6])
        .with_expected_left_keys(&[1, 2, 3])
        .with_expected_right_keys(&[4, 5, 6, 7, 8, 9])
        .with_expected_stats(&[3, 3, 0])
        .with_job_builder(|ds| {
            MergeInsertBuilder::try_new(ds, vec!["key".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .try_build()
                .unwrap()
        })
        .run_test()
        .await;
}
#[rstest::rstest]
#[tokio::test]
async fn test_conditional_update_with_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // Conditional update: only matched rows whose `filterme` differs
    // between source and target get rewritten (just key 6 here), while
    // the unmatched source rows 7-9 are still inserted.
    let source_keys = [6];
    let expected_left = [1, 2, 3, 4, 5];
    let expected_right = [6, 7, 8, 9];
    // Expected merge statistics in (inserted, updated, deleted) order.
    let expected_stats = [3, 1, 0];
    MergeInsertTestBuilder::new()
        .with_version(version)
        .with_stable_row_ids(enable_stable_row_ids)
        .with_test_keys(&source_keys)
        .with_expected_left_keys(&expected_left)
        .with_expected_right_keys(&expected_right)
        .with_expected_stats(&expected_stats)
        .with_job_builder(|target| {
            // Build the conditional match clause first, then the job.
            let update_clause =
                WhenMatched::update_if(&target, "source.filterme != target.filterme").unwrap();
            MergeInsertBuilder::try_new(target.clone(), vec!["key".to_string()])
                .unwrap()
                .when_matched(update_clause)
                .try_build()
                .unwrap()
        })
        .run_test()
        .await;
}
#[rstest::rstest]
#[tokio::test]
async fn test_update_only_with_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // Update-only merge: matched rows 4-6 are rewritten, unmatched source
    // rows are dropped (DoNothing), and nothing is deleted.
    let source_keys = [4, 5, 6];
    let expected_left = [1, 2, 3];
    let expected_right = [4, 5, 6];
    // Expected merge statistics in (inserted, updated, deleted) order.
    let expected_stats = [0, 3, 0];
    MergeInsertTestBuilder::new()
        .with_version(version)
        .with_stable_row_ids(enable_stable_row_ids)
        .with_test_keys(&source_keys)
        .with_expected_left_keys(&expected_left)
        .with_expected_right_keys(&expected_right)
        .with_expected_stats(&expected_stats)
        .with_job_builder(|target| {
            let join_keys = vec!["key".to_string()];
            MergeInsertBuilder::try_new(target, join_keys)
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched(WhenNotMatched::DoNothing)
                .try_build()
                .unwrap()
        })
        .run_test()
        .await;
}
#[rstest::rstest]
#[tokio::test]
async fn test_upsert_with_conditional_delete_and_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // Upsert combined with a conditional delete of unmatched target rows:
    // only key 1 survives on the left side; two rows get deleted.
    let source_keys = [1, 4, 5, 6];
    let expected_left = [1];
    let expected_right = [4, 5, 6, 7, 8, 9];
    // Expected merge statistics in (inserted, updated, deleted) order.
    let expected_stats = [3, 3, 2];
    MergeInsertTestBuilder::new()
        .with_version(version)
        .with_stable_row_ids(enable_stable_row_ids)
        .with_test_keys(&source_keys)
        .with_expected_left_keys(&expected_left)
        .with_expected_right_keys(&expected_right)
        .with_expected_stats(&expected_stats)
        .with_job_builder(|target| {
            // Shared helper provides the delete predicate used by the
            // sibling DeleteIf tests.
            let delete_condition = create_delete_condition();
            MergeInsertBuilder::try_new(target, vec!["key".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(delete_condition))
                .try_build()
                .unwrap()
        })
        .run_test()
        .await;
}
#[rstest::rstest]
#[tokio::test]
async fn test_multiple_merge_insert_stable_row_id(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // Runs three consecutive update-only merge inserts against the same key
    // and verifies after each one that the row's data was rewritten and, when
    // stable row ids are enabled, that the row id never changed.
    let schema = create_test_schema();
    let test_uri = "memory://test_multiple_merge.lance";
    let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
    let target_key = 2u32;
    let target_keys = vec![target_key];
    // Baseline row id captured before any merge runs.
    let initial_row_ids = get_row_ids_for_keys(&ds, &target_keys).await;
    let initial_row_id = initial_row_ids.value(0);
    let mut current_ds = ds;
    for iteration in 1..=3 {
        // Each iteration writes a distinct (value, filterme) pair so we can
        // tell which iteration's data is present.
        let new_value = 1000u32 + iteration * 10;
        let new_batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt32Array::from(vec![target_key])),
                Arc::new(UInt32Array::from(vec![new_value])),
                Arc::new(StringArray::from(vec![format!("iteration_{}", iteration)])),
            ],
        )
        .unwrap();
        // Update-only job: matched rows are rewritten; unmatched source rows
        // are silently dropped.
        let job = MergeInsertBuilder::try_new(current_ds.clone(), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::DoNothing)
            .try_build()
            .unwrap();
        let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
        let new_stream = reader_to_stream(new_reader);
        let (updated_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
        assert_eq!(
            merge_stats.num_updated_rows, 1,
            "Iteration {}: Expected 1 updated row",
            iteration
        );
        assert_eq!(
            merge_stats.num_inserted_rows, 0,
            "Iteration {}: Expected 0 inserted rows",
            iteration
        );
        assert_eq!(
            merge_stats.num_deleted_rows, 0,
            "Iteration {}: Expected 0 deleted rows",
            iteration
        );
        let updated_row_ids = get_row_ids_for_keys(&updated_dataset, &target_keys).await;
        let updated_row_id = updated_row_ids.value(0);
        // Re-read the target row and confirm both columns carry this
        // iteration's values.
        let updated_batch = updated_dataset
            .scan()
            .filter(&format!("key = {}", target_key))
            .unwrap()
            .try_into_batch()
            .await
            .unwrap();
        let value_col = updated_batch
            .column_by_name("value")
            .unwrap()
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let filterme_col = updated_batch
            .column_by_name("filterme")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(
            value_col.value(0),
            new_value,
            "Iteration {}: Value should be updated to {}",
            iteration,
            new_value
        );
        assert_eq!(filterme_col.value(0), format!("iteration_{}", iteration));
        // Row-id stability is only guaranteed when the feature is enabled;
        // without it the id is expected to change, so nothing is asserted.
        if enable_stable_row_ids {
            assert_eq!(
                updated_row_id, initial_row_id,
                "Iteration {}: Row ID should remain stable across merge inserts when stable_row_ids is enabled. Initial: {}, Current: {}",
                iteration, initial_row_id, updated_row_id
            );
        }
        current_ds = updated_dataset;
    }
    // After all three iterations the row must exist exactly once and carry
    // the values written by the final iteration (1030 / "iteration_3").
    let final_batch = current_ds
        .scan()
        .filter(&format!("key = {}", target_key))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(
        final_batch.num_rows(),
        1,
        "Should have exactly one row for the target key"
    );
    let final_value = final_batch
        .column_by_name("value")
        .unwrap()
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap()
        .value(0);
    let final_filterme = final_batch
        .column_by_name("filterme")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap()
        .value(0);
    assert_eq!(
        final_value, 1030u32,
        "Final value should be from last iteration"
    );
    assert_eq!(
        final_filterme, "iteration_3",
        "Final filterme should be from last iteration"
    );
}
#[rstest::rstest]
#[tokio::test]
async fn test_row_id_stability_across_update_and_merge_insert(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // Verifies row-id behavior for a single row (key = 2) across two mutation
    // paths run back to back: an UpdateBuilder update followed by a merge
    // insert. With stable row ids the id must never change; without them it
    // must change after each rewrite.
    // (A redundant trailing re-assertion of the same row-id equality was
    // removed; the check inside the post-merge branch already covers it.)
    let schema = create_test_schema();
    let test_uri = "memory://test_row_id_stability.lance";
    let mut dataset = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
    let target_key = 2u32;
    let target_keys = vec![target_key];
    // Baseline row id and value before any mutation.
    let initial_row_ids = get_row_ids_for_keys(&dataset, &target_keys).await;
    let initial_row_id = initial_row_ids.value(0);
    let initial_batch = dataset
        .scan()
        .filter(&format!("key = {}", target_key))
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    let initial_value = initial_batch
        .column_by_name("value")
        .unwrap()
        .as_primitive::<UInt32Type>()
        .value(0);
    // Step 1: SQL-style update of the target row (value += 100).
    let update_result = crate::dataset::UpdateBuilder::new(Arc::new((*dataset).clone()))
        .update_where(&format!("key = {}", target_key))
        .unwrap()
        .set("value", "value + 100")
        .unwrap()
        .build()
        .unwrap()
        .execute()
        .await
        .unwrap();
    dataset = update_result.new_dataset.clone();
    let after_update_row_ids = get_row_ids_for_keys(&dataset, &target_keys).await;
    let after_update_row_id = after_update_row_ids.value(0);
    let after_update_batch = dataset
        .scan()
        .filter(&format!("key = {}", target_key))
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    let after_update_value = after_update_batch
        .column_by_name("value")
        .unwrap()
        .as_primitive::<UInt32Type>()
        .value(0);
    if enable_stable_row_ids {
        assert_eq!(
            initial_row_id, after_update_row_id,
            "Row ID should remain stable after update"
        );
    } else {
        assert_ne!(
            initial_row_id, after_update_row_id,
            "Row ID should change after update when stable row IDs are disabled"
        );
    }
    assert_eq!(
        after_update_value,
        initial_value + 100,
        "Value should be updated correctly"
    );
    // Step 2: merge insert rewriting the same row with a fresh value and
    // filterme marker.
    let merge_new_value = 500u32;
    let new_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![target_key])),
            Arc::new(UInt32Array::from(vec![merge_new_value])),
            Arc::new(StringArray::from(vec!["UPDATED"])),
        ],
    )
    .unwrap();
    let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .try_build()
        .unwrap();
    let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
    let new_stream = reader_to_stream(new_reader);
    let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
    let after_merge_row_ids = get_row_ids_for_keys(&merged_dataset, &target_keys).await;
    let after_merge_row_id = after_merge_row_ids.value(0);
    let after_merge_batch = merged_dataset
        .scan()
        .filter(&format!("key = {}", target_key))
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    let after_merge_value = after_merge_batch
        .column_by_name("value")
        .unwrap()
        .as_primitive::<UInt32Type>()
        .value(0);
    let after_merge_filterme = after_merge_batch
        .column_by_name("filterme")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap()
        .value(0);
    if enable_stable_row_ids {
        // Stable ids: unchanged through both the update and the merge.
        assert_eq!(
            initial_row_id, after_merge_row_id,
            "Row ID should remain stable after merge insert"
        );
        assert_eq!(
            after_update_row_id, after_merge_row_id,
            "Row ID should remain the same across update and merge insert"
        );
    } else {
        assert_ne!(
            after_update_row_id, after_merge_row_id,
            "Row ID should change after merge insert when stable row IDs are disabled"
        );
    }
    assert_eq!(
        after_merge_value, merge_new_value,
        "Value should be updated by merge insert"
    );
    assert_eq!(
        after_merge_filterme, "UPDATED",
        "Filterme should be updated by merge insert"
    );
    assert_eq!(
        merge_stats.num_updated_rows, 1,
        "Should update exactly 1 row"
    );
    assert_eq!(
        merge_stats.num_inserted_rows, 0,
        "Should not insert any new rows"
    );
    assert_eq!(
        merge_stats.num_deleted_rows, 0,
        "Should not delete any rows"
    );
}
#[tokio::test]
async fn test_indexed_merge_insert() {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(1))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
let schema = data.schema();
let mut ds = Dataset::write(data, test_uri, None).await.unwrap();
let index_params = ScalarIndexParams::default();
ds.create_index(&["key"], IndexType::Scalar, None, &index_params, false)
.await
.unwrap();
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(2))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(8));
let ds = Dataset::write(
data,
test_uri,
Some(WriteParams {
mode: WriteMode::Append,
..Default::default()
}),
)
.await
.unwrap();
let ds = Arc::new(ds);
let just_index_col = Schema::new(vec![Field::new("key", DataType::Utf8, false)]);
let some_indices = ds
.sample(2048, &(&just_index_col).try_into().unwrap())
.await
.unwrap();
let some_indices = some_indices.column(0).clone();
let some_vals = lance_datagen::gen_batch()
.anon_col(array::fill::<UInt32Type>(9999999))
.into_batch_rows(RowCount::from(2048))
.unwrap();
let some_vals = some_vals.column(0).clone();
let source_batch =
RecordBatch::try_new(schema.clone(), vec![some_vals, some_indices]).unwrap();
let source_batches = vec![
source_batch.slice(0, 512),
source_batch.slice(512, 512),
source_batch.slice(1024, 512),
source_batch.slice(1536, 512),
];
let source = Box::new(RecordBatchIterator::new(
source_batches.clone().into_iter().map(Ok),
schema.clone(),
));
let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
.unwrap()
.when_not_matched(WhenNotMatched::DoNothing)
.when_matched(WhenMatched::UpdateAll)
.try_build()
.unwrap()
.execute_reader(source)
.await
.unwrap();
let updated = ds
.count_rows(Some("value = 9999999".to_string()))
.await
.unwrap();
assert_eq!(updated, 2048);
let source = Box::new(RecordBatchIterator::new(
source_batches.clone().into_iter().map(Ok),
schema.clone(),
));
let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
.unwrap()
.when_not_matched(WhenNotMatched::DoNothing)
.when_matched(WhenMatched::UpdateAll)
.when_not_matched_by_source(WhenNotMatchedBySource::Delete)
.try_build()
.unwrap()
.execute_reader(source)
.await
.unwrap();
assert_eq!(ds.count_rows(None).await.unwrap(), 2048);
let source = Box::new(RecordBatchIterator::new(
source_batches.clone().into_iter().map(Ok),
schema.clone(),
));
let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
.unwrap()
.when_not_matched(WhenNotMatched::DoNothing)
.when_matched(WhenMatched::UpdateAll)
.try_build()
.unwrap()
.execute_reader(source)
.await
.unwrap();
assert_eq!(ds.count_rows(None).await.unwrap(), 2048);
}
mod subcols {
    //! Tests for merge insert where the source carries only a subset of the
    //! target's columns (partial-column / "subcols" update path).
    use super::*;
    use rstest::rstest;
    /// Shared fixture: a 1024(+2)-row target dataset plus a partial-schema
    /// source batch.
    struct Fixtures {
        // Target dataset: two 512-row appends, 256 rows per file.
        ds: Arc<Dataset>,
        // Source with only (key, value) columns: 270 keys that exist in the
        // target plus 2 brand-new keys (from batch3).
        new_data: RecordBatch,
    }
    async fn setup(scalar_index: bool) -> Fixtures {
        // Generate 1026 rows; first 1024 become the target, last 2 are held
        // back as never-before-seen keys for insert testing.
        let data = lance_datagen::gen_batch()
            .with_seed(Seed::from(1))
            .col("other", array::rand_utf8(4.into(), false))
            .col("value", array::step::<UInt32Type>())
            .col("key", array::rand_pseudo_uuid_hex());
        let batch = data.into_batch_rows(RowCount::from(1024 + 2)).unwrap();
        let batch1 = batch.slice(0, 512);
        let batch2 = batch.slice(512, 512);
        let batch3 = batch.slice(1024, 2);
        let schema = batch.schema();
        let reader = Box::new(RecordBatchIterator::new(
            [Ok(batch1.clone())],
            schema.clone(),
        ));
        // Small files/groups so the dataset has multiple fragments (256 rows
        // each) for the fragment-level assertions below.
        let write_params = WriteParams {
            max_rows_per_file: 256,
            max_rows_per_group: 32,
            ..Default::default()
        };
        let mut ds = Dataset::write(reader, "memory://", Some(write_params.clone()))
            .await
            .unwrap();
        // Optionally index the join key so the scalar-index merge path is
        // taken; the index only covers the first append.
        if scalar_index {
            let index_params = ScalarIndexParams::default();
            ds.create_index(&["key"], IndexType::Scalar, None, &index_params, false)
                .await
                .unwrap();
        }
        let reader = Box::new(RecordBatchIterator::new(
            [Ok(batch2.clone())],
            batch2.schema(),
        ));
        ds.append(reader, Some(write_params)).await.unwrap();
        let ds = Arc::new(ds);
        // Source schema = (key, value) only — columns 2 and 1 of the target.
        let update_schema = Arc::new(schema.project(&[2, 1]).unwrap());
        // Existing keys drawn from fragments 1 and 2 (rows 256..512,
        // 600..612, 712, 715); fragments 0 and 3 are untouched.
        let indices: Int64Array = (256..512).chain(600..612).chain([712, 715]).collect();
        let keys = arrow::compute::take(batch["key"].as_ref(), &indices, None).unwrap();
        let keys = arrow::compute::concat(&[&keys, &batch3["key"]]).unwrap();
        let num_rows = keys.len();
        // New values start at 1024 so updated rows are distinguishable from
        // the original step values.
        let new_data = RecordBatch::try_new(
            update_schema,
            vec![
                keys,
                Arc::new((1024..(1024 + num_rows as u32)).collect::<UInt32Array>()),
            ],
        )
        .unwrap();
        Fixtures { ds, new_data }
    }
    #[tokio::test]
    async fn test_delete_not_supported() {
        // WhenNotMatchedBySource::Delete is rejected when the source schema
        // is a subset of the target schema.
        let Fixtures { ds, new_data } = Box::pin(setup(false)).await;
        let reader = Box::new(RecordBatchIterator::new(
            [Ok(new_data.clone())],
            new_data.schema(),
        ));
        let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
            .unwrap()
            .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::DoNothing)
            .try_build()
            .unwrap();
        let res = assert_send(job.execute_reader(reader)).await;
        assert!(
            matches!(
                &res,
                &Err(Error::NotSupported { ref source, .. })
                if source.to_string().contains("Deleting rows from the target table when there is no match in the source table is not supported when the source data has a different schema than the target data"),
            ),
            "Expected NotSupported error, got: {:?}",
            res
        );
    }
    #[tokio::test]
    async fn test_errors_on_bad_schema() {
        // Source columns whose names match nothing in the target must fail
        // with a SchemaMismatch error rather than silently proceeding.
        let Fixtures { ds, new_data } = Box::pin(setup(false)).await;
        let bad_schema = Arc::new(Schema::new(vec![
            Field::new("wrong_key", DataType::Utf8, false),
            Field::new("wrong_value", DataType::UInt32, false),
        ]));
        let bad_batch =
            RecordBatch::try_new(bad_schema.clone(), new_data.columns().to_vec()).unwrap();
        let reader = Box::new(RecordBatchIterator::new([Ok(bad_batch)], bad_schema));
        let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::DoNothing)
            .try_build()
            .unwrap();
        let res = job.execute_reader(reader).await;
        assert!(
            matches!(
                &res,
                &Err(Error::SchemaMismatch { ref difference, .. })
                if difference.clone().contains("fields did not match")
            ),
            "Expected SchemaMismatch error, got: {:?}",
            res
        );
    }
    #[rstest]
    #[tokio::test]
    async fn test_merge_insert_subcols(
        #[values(false, true)] scalar_index: bool,
        #[values(false, true)] insert: bool,
    ) {
        // Partial-column merge: touched fragments should get a NEW data file
        // holding the updated columns instead of being fully rewritten.
        let Fixtures { ds, new_data } = Box::pin(setup(scalar_index)).await;
        let reader = Box::new(RecordBatchIterator::new(
            [Ok(new_data.clone())],
            new_data.schema(),
        ));
        let fragments_before = ds
            .get_fragments()
            .iter()
            .map(|f| f.metadata().clone())
            .collect::<Vec<_>>();
        let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(if insert {
                WhenNotMatched::InsertAll
            } else {
                WhenNotMatched::DoNothing
            })
            .try_build()
            .unwrap();
        let (ds, stats) = job.execute_reader(reader).await.unwrap();
        let fragments_after = ds
            .get_fragments()
            .iter()
            .map(|f| f.metadata().clone())
            .collect::<Vec<_>>();
        // Fragment ids are preserved; only contents of touched fragments
        // change.
        assert_eq!(
            fragments_before.iter().map(|f| f.id).collect::<Vec<_>>(),
            fragments_after
                .iter()
                .take(fragments_before.len())
                .map(|f| f.id)
                .collect::<Vec<_>>()
        );
        // Source keys only hit fragments 1 and 2; 0 and 3 must be untouched.
        assert_eq!(fragments_before[0], fragments_after[0]);
        assert_ne!(fragments_before[1], fragments_after[1]);
        assert_ne!(fragments_before[2], fragments_after[2]);
        assert_eq!(fragments_before[3], fragments_after[3]);
        // Touched fragments get a second data file: the original file keeps
        // only column 0 (-2 marks tombstoned fields) and the new file holds
        // the rewritten (key, value) columns.
        let has_added_files = |frag: &Fragment| {
            assert_eq!(frag.files.len(), 2);
            let data_files = &frag.files;
            assert_eq!(&data_files[0].fields, &[0, -2, -2]);
            assert_eq!(&data_files[1].fields, &[2, 1]);
        };
        has_added_files(&fragments_after[1]);
        has_added_files(&fragments_after[2]);
        // With insert enabled, the 2 never-seen keys land in a new fragment.
        if insert {
            assert_eq!(fragments_after.len(), 5);
            assert_eq!(stats.num_inserted_rows, 2);
        } else {
            assert_eq!(fragments_after.len(), 4);
            assert_eq!(stats.num_inserted_rows, 0);
        }
        assert_eq!(stats.num_updated_rows, (new_data.num_rows() - 2) as u64);
        assert_eq!(stats.num_deleted_rows, 0);
        // Spot-check values: row 0 untouched, row 256 is the first updated
        // row (value 1024), row 512 untouched, row 715 is the last updated
        // existing key.
        let data = ds
            .scan()
            .scan_in_order(true)
            .try_into_batch()
            .await
            .unwrap();
        assert_eq!(data.num_rows(), if insert { 1024 + 2 } else { 1024 });
        assert_eq!(data.num_columns(), 3);
        let values = data
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        assert_eq!(values.value(0), 0);
        assert_eq!(values.value(256), 1024);
        assert_eq!(values.value(512), 512);
        assert_eq!(values.value(715), 1024 + new_data.num_rows() as u32 - 3);
        if insert {
            assert_eq!(values.value(1024), 1024 + new_data.num_rows() as u32 - 2);
        }
    }
}
#[cfg(not(windows))]
#[rstest::rstest]
#[case::all_success(Duration::from_secs(100_000))]
#[case::timeout(Duration::from_millis(200))]
#[tokio::test]
async fn test_merge_insert_concurrency(#[case] timeout: Duration) {
    // Launches 10 concurrent upserts (one distinct key each) through a
    // throttled object store to force commit conflicts and exercise the
    // retry path. With the generous timeout every task must succeed; with
    // the 200ms timeout some tasks may fail with TooMuchWriteContention.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, false),
    ]));
    let concurrency = 10;
    // One row per task, all starting at value 0.
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from_iter_values(0..concurrency)),
            Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(
                0,
                concurrency as usize,
            ))),
        ],
    )
    .unwrap();
    // Slow every store call down so the concurrent commits reliably overlap.
    let throttled = Arc::new(ThrottledStoreWrapper {
        config: ThrottleConfig {
            wait_list_per_call: Duration::from_millis(20),
            wait_get_per_call: Duration::from_millis(20),
            wait_put_per_call: Duration::from_millis(20),
            ..Default::default()
        },
    });
    // Shared session so every task sees the same in-memory store.
    let session = Arc::new(Session::default());
    let mut dataset = InsertBuilder::new("memory://")
        .with_params(&WriteParams {
            store_params: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .execute(vec![initial_data])
        .await
        .unwrap();
    // Barrier makes all tasks kick off their merge at the same moment.
    let barrier = Arc::new(Barrier::new(concurrency as usize));
    let mut handles = Vec::new();
    for i in 0..concurrency {
        let session_ref = session.clone();
        let schema_ref = schema.clone();
        let barrier_ref = barrier.clone();
        let throttled_ref = throttled.clone();
        let handle = tokio::task::spawn(async move {
            // Each task loads its own handle to the same dataset.
            let dataset = DatasetBuilder::from_uri("memory://")
                .with_read_params(ReadParams {
                    store_options: Some(ObjectStoreParams {
                        object_store_wrapper: Some(throttled_ref.clone()),
                        ..Default::default()
                    }),
                    session: Some(session_ref.clone()),
                    ..Default::default()
                })
                .load()
                .await
                .unwrap();
            let dataset = Arc::new(dataset);
            // Task i upserts value=1 for its own key i.
            let new_data = RecordBatch::try_new(
                schema_ref.clone(),
                vec![
                    Arc::new(UInt32Array::from(vec![i])),
                    Arc::new(UInt32Array::from(vec![1])),
                ],
            )
            .unwrap();
            let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema_ref.clone()));
            let job = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched(WhenNotMatched::InsertAll)
                .conflict_retries(100)
                .retry_timeout(timeout)
                .try_build()
                .unwrap();
            barrier_ref.wait().await;
            // Return the number of attempts so the test can bound retries.
            job.execute_reader(source)
                .await
                .map(|(_ds, stats)| stats.num_attempts)
        });
        handles.push(handle);
    }
    let results = try_join_all(handles).await.unwrap();
    for attempts in results.iter() {
        match attempts {
            Ok(attempts) => {
                assert!(*attempts <= 10, "Attempt count should be <= 10");
            }
            Err(err) => {
                // The only acceptable failure mode is exhausting the retry
                // timeout (expected in the short-timeout case).
                assert!(
                    matches!(err, Error::TooMuchWriteContention { message, .. } if message.contains("failed on retry_timeout")),
                    "Expected TooMuchWriteContention error, got: {:?}",
                    err
                );
            }
        }
    }
    // Only validate final contents for the all-success case; with the short
    // timeout some upserts may legitimately not have landed.
    if timeout.as_secs() > 10 {
        dataset.checkout_latest().await.unwrap();
        let batches = dataset.scan().try_into_batch().await.unwrap();
        let values = batches["value"].as_primitive::<UInt32Type>();
        assert!(
            values.values().iter().all(|&v| v == 1),
            "All values should be 1 after merge insert. Got: {:?}",
            values
        );
    }
}
#[tokio::test]
async fn test_merge_insert_large_concurrent() {
    // A small merge (value=1 for id=1) is prepared but left uncommitted,
    // then committed while a larger 1000-row merge (value=2 for all ids) is
    // still streaming its source. The large merge must detect the conflict,
    // retry, and win: every row ends at value 2.
    // Fix: the final assertion message previously said "All values should
    // be 1" while the assert checks for 2.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, false),
    ]));
    let num_rows = 10;
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from_iter_values(0..num_rows)),
            Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(
                0,
                num_rows as usize,
            ))),
        ],
    )
    .unwrap();
    // Throttle store reads so the large merge stays in flight long enough
    // for the competing commit to land first.
    let throttled = Arc::new(ThrottledStoreWrapper {
        config: ThrottleConfig {
            wait_list_per_call: Duration::from_millis(10),
            wait_get_per_call: Duration::from_millis(10),
            ..Default::default()
        },
    });
    let session = Arc::new(Session::default());
    let dataset = InsertBuilder::new("memory://")
        .with_params(&WriteParams {
            store_params: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .execute(vec![initial_data])
        .await
        .unwrap();
    let dataset = Arc::new(dataset);
    // Merge #1: value=1 for id=1, executed uncommitted so we control when
    // its transaction actually lands.
    let new_data1 = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![1])),
            Arc::new(UInt32Array::from(vec![1])),
        ],
    )
    .unwrap();
    let UncommittedMergeInsert {
        transaction: transaction1,
        ..
    } = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute_uncommitted(RecordBatchIterator::new(
            vec![Ok(new_data1)],
            schema.clone(),
        ))
        .await
        .unwrap();
    // Merge #2: ids 0..1000 all set to value=2, fed in ten 100-row batches.
    // `notify` fires when the final batch is produced, signalling that the
    // large merge has consumed its whole source.
    let new_data2 = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from_iter_values(0..1000)),
            Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(2, 1000))),
        ],
    )
    .unwrap();
    let notify = Arc::new(Notify::new());
    let source = RecordBatchIterator::new(
        (0..10)
            .map(|i| {
                let batch = new_data2.slice(i * 100, 100);
                if i == 9 {
                    notify.notify_one();
                }
                Ok(batch)
            })
            .collect::<Vec<_>>(),
        schema.clone(),
    );
    let dataset2 = DatasetBuilder::from_uri("memory://")
        .with_read_params(ReadParams {
            store_options: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .load()
        .await
        .unwrap();
    let job = MergeInsertBuilder::try_new(Arc::new(dataset2), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute_reader(source);
    let task = tokio::task::spawn(job);
    notify.notified().await;
    // Commit the small merge while the large one is mid-flight, forcing the
    // large merge to rebase/retry over it.
    let mut dataset = CommitBuilder::new(dataset)
        .execute(transaction1)
        .await
        .unwrap();
    task.await.unwrap().unwrap();
    dataset.checkout_latest().await.unwrap();
    let batches = dataset.scan().try_into_batch().await.unwrap();
    let values = batches["value"].as_primitive::<UInt32Type>();
    assert!(
        values.values().iter().all(|&v| v == 2),
        "All values should be 2 after the large merge insert. Got: {:?}",
        values
    );
}
#[tokio::test]
async fn test_merge_insert_updates_indices() {
let test_dataset = async || {
let mut dataset = lance_datagen::gen_batch()
.col("id", array::step::<UInt32Type>())
.col("value", array::step::<UInt32Type>())
.col("other_value", array::step::<UInt32Type>())
.into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(20))
.await
.unwrap();
dataset
.create_index(
&["id"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
dataset
.create_index(
&["value"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
dataset
.create_index(
&["other_value"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
Arc::new(dataset)
};
let check_indices = async |dataset: &Dataset, id_frags: &[u32], value_frags: &[u32]| {
let id_index = dataset
.load_scalar_index(IndexCriteria::default().with_name("id_idx"))
.await
.unwrap();
if id_frags.is_empty() {
assert!(id_index.is_none());
} else {
let id_index = id_index.unwrap();
let id_frags_bitmap = RoaringBitmap::from_iter(id_frags.iter().copied());
let effective_bitmap = id_index
.effective_fragment_bitmap(&dataset.fragment_bitmap)
.unwrap();
assert_eq!(effective_bitmap, id_frags_bitmap);
}
let value_index = dataset
.load_scalar_index(IndexCriteria::default().with_name("value_idx"))
.await
.unwrap();
if value_frags.is_empty() {
assert!(value_index.is_none());
} else {
let value_index = value_index.unwrap();
let value_frags_bitmap = RoaringBitmap::from_iter(value_frags.iter().copied());
let effective_bitmap = value_index
.effective_fragment_bitmap(&dataset.fragment_bitmap)
.unwrap();
assert_eq!(effective_bitmap, value_frags_bitmap);
}
let other_value_index = dataset
.load_scalar_index(IndexCriteria::default().with_name("other_value_idx"))
.await
.unwrap()
.unwrap();
let effective_bitmap = other_value_index
.effective_fragment_bitmap(&dataset.fragment_bitmap)
.unwrap();
let index_bitmap = other_value_index.fragment_bitmap.as_ref().unwrap();
let expected_bitmap = index_bitmap & dataset.fragment_bitmap.as_ref();
assert_eq!(
effective_bitmap, expected_bitmap,
"other_value index effective bitmap should be intersection. index_bitmap: {:?}, dataset_fragments: {:?}, effective_bitmap: {:?}",
index_bitmap, dataset.fragment_bitmap, effective_bitmap
);
};
let dataset = test_dataset().await;
check_indices(&dataset, &[0, 1, 2, 3], &[0, 1, 2, 3]).await;
let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let (dataset, _) = merge_insert
.execute_reader(
lance_datagen::gen_batch()
.col("id", array::step_custom::<UInt32Type>(50, 1))
.col("value", array::step_custom::<UInt32Type>(50, 1))
.col("other_value", array::step_custom::<UInt32Type>(50, 1))
.into_df_stream(RowCount::from(40), BatchCount::from(1)),
)
.await
.unwrap();
check_indices(&dataset, &[0, 1, 2], &[0, 1, 2]).await;
let dataset = test_dataset().await;
let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let (dataset, _) = merge_insert
.execute_reader(
lance_datagen::gen_batch()
.col("id", array::step_custom::<UInt32Type>(50, 1))
.col("value", array::step_custom::<UInt32Type>(50, 1))
.into_df_stream(RowCount::from(40), BatchCount::from(1)),
)
.await
.unwrap();
check_indices(&dataset, &[0, 1], &[0, 1]).await;
let dataset = test_dataset().await;
let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let (dataset, _) = merge_insert
.execute_reader(
lance_datagen::gen_batch()
.col("id", array::step_custom::<UInt32Type>(10, 1))
.col("value", array::step_custom::<UInt32Type>(10, 1))
.into_df_stream(RowCount::from(80), BatchCount::from(1)),
)
.await
.unwrap();
check_indices(&dataset, &[], &[]).await;
}
#[tokio::test]
async fn test_upsert_concurrent_full_frag() {
    // Two concurrent upserts each rewrite one of the two rows of a tiny
    // dataset. After both land, no fragment may be left with zero rows and
    // both rows must hold the new value.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, false),
    ]));
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![0, 1])),
            Arc::new(UInt32Array::from(vec![0, 0])),
        ],
    )
    .unwrap();
    // Throttle every store call so the two commits actually race.
    let throttled = Arc::new(ThrottledStoreWrapper {
        config: ThrottleConfig {
            wait_list_per_call: Duration::from_millis(5),
            wait_get_per_call: Duration::from_millis(5),
            wait_put_per_call: Duration::from_millis(5),
            ..Default::default()
        },
    });
    let session = Arc::new(Session::default());
    let mut dataset = InsertBuilder::new("memory://")
        .with_params(&WriteParams {
            store_params: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .execute(vec![initial_data])
        .await
        .unwrap();
    // Barrier makes both tasks start their merge at the same moment.
    let barrier = Arc::new(Barrier::new(2));
    let mut handles = Vec::new();
    for i in 0..2 {
        // Task i upserts value=1 for id=i.
        let new_data = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt32Array::from(vec![i])),
                Arc::new(UInt32Array::from(vec![1])),
            ],
        )
        .unwrap();
        let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema.clone()));
        let dataset_ref = Arc::new(dataset.clone());
        let barrier = barrier.clone();
        let handle = tokio::spawn(async move {
            barrier.wait().await;
            MergeInsertBuilder::try_new(dataset_ref, vec!["id".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched(WhenNotMatched::InsertAll)
                .try_build()
                .unwrap()
                .execute_reader(source)
                .await
                .unwrap();
        });
        handles.push(handle);
    }
    try_join_all(handles).await.unwrap();
    dataset.checkout_latest().await.unwrap();
    // A fully-rewritten fragment must be dropped, not left behind empty.
    assert!(
        dataset
            .get_fragments()
            .iter()
            .all(|f| f.metadata().num_rows().unwrap() > 0),
        "No fragments should have zero rows after upsert"
    );
    let batches = dataset.scan().try_into_batch().await.unwrap();
    let values = batches["value"].as_primitive::<UInt32Type>();
    assert!(
        values.values().iter().all(|&v| v == 1),
        "All values should be 1 after merge insert. Got: {:?}",
        values
    );
}
#[tokio::test]
async fn test_plan_upsert() {
    // Snapshot test of the physical plan built for a plain upsert
    // (UpdateAll with the default InsertAll / Keep clauses): a LanceRead of
    // the target feeding a right hash join against the streamed source,
    // followed by the __action projection and the MergeInsert node.
    let data = lance_datagen::gen_batch()
        .with_seed(Seed::from(1))
        .col("value", array::step::<UInt32Type>())
        .col("key", array::rand_pseudo_uuid_hex());
    let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
    let _schema = data.schema();
    let ds = Dataset::write(data, "memory://", None).await.unwrap();
    let merge_insert_job =
        crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
            .unwrap()
            .when_matched(crate::dataset::WhenMatched::UpdateAll)
            .try_build()
            .unwrap();
    // Source stream with different data than the target (seed 2).
    let new_data = lance_datagen::gen_batch()
        .with_seed(Seed::from(2))
        .col("value", array::step::<UInt32Type>())
        .col("key", array::rand_pseudo_uuid_hex());
    let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
    let new_data_stream = reader_to_stream(Box::new(new_data));
    let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
    // Expected plan text; the literal must match the planner's output
    // exactly (assert_plan_node_equals compares against it).
    assert_plan_node_equals(
        plan,
        "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@1 as _rowid, _rowaddr@2 as _rowaddr, value@3 as value, key@4 as key, CASE WHEN __common_expr_1@0 AND _rowaddr@2 IS NULL THEN 2 WHEN __common_expr_1@0 AND _rowaddr@2 IS NOT NULL THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[key@3 IS NOT NULL as __common_expr_1, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key]
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4]
CooperativeExec
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \
row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec: partitioning=RoundRobinBatch(...), input_partitions=1
StreamingTableExec: partition_sizes=1, projection=[value, key]"
    ).await.unwrap();
}
#[tokio::test]
async fn test_fast_path_update_only() {
// With when_not_matched=DoNothing only matched rows matter, so the planner
// can use an Inner join (see join_type=Inner below) instead of the Right
// join used by a full upsert.
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(1))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
let ds = Dataset::write(data, "memory://", None).await.unwrap();
let merge_insert_job =
crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
.unwrap()
.when_matched(crate::dataset::WhenMatched::UpdateAll)
.when_not_matched(crate::dataset::WhenNotMatched::DoNothing)
.try_build()
.unwrap();
// Source stream generated with a different seed.
let new_data = lance_datagen::gen_batch()
.with_seed(Seed::from(2))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
let new_data_stream = reader_to_stream(Box::new(new_data));
let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
// __action is only ever 1 (update) or 0 here; no insert branch is planned.
assert_plan_node_equals(
plan,
"MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, CASE WHEN key@3 IS NOT NULL AND _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4]
CooperativeExec
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
StreamingTableExec: partition_sizes=1, projection=[value, key]"
).await.unwrap();
}
#[tokio::test]
async fn test_fast_path_conditional_update() {
// Like test_fast_path_update_only, but the update is guarded by
// `source.value > 20`; the condition is folded into the __action CASE
// expression (see `value@2 > 20` in the expected plan).
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(1))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
let ds = Dataset::write(data, "memory://", None).await.unwrap();
let merge_insert_job = crate::dataset::MergeInsertBuilder::try_new(
Arc::new(ds.clone()),
vec!["key".to_string()],
)
.unwrap()
.when_matched(crate::dataset::WhenMatched::update_if(&ds, "source.value > 20").unwrap())
.when_not_matched(crate::dataset::WhenNotMatched::DoNothing)
.try_build()
.unwrap();
// Source stream generated with a different seed.
let new_data = lance_datagen::gen_batch()
.with_seed(Seed::from(2))
.col("value", array::step::<UInt32Type>())
.col("key", array::rand_pseudo_uuid_hex());
let new_data_reader = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
let new_data_stream = reader_to_stream(Box::new(new_data_reader));
let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
assert_plan_node_equals(
plan,
"MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, CASE WHEN key@3 IS NOT NULL AND _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4]
CooperativeExec
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
StreamingTableExec: partition_sizes=1, projection=[value, key]"
).await.unwrap();
}
#[tokio::test]
async fn test_skip_auto_cleanup() {
    // Verifies that `skip_auto_cleanup(true)` on a merge-insert suppresses the
    // dataset's auto-cleanup, while merges without the flag let cleanup run.
    // Cleanup is configured aggressively (interval: 1, older_than: 0ms) so any
    // commit without the flag removes older versions.
    //
    // NOTE: a previous revision also built an `auto_cleanup_params` HashMap of
    // "lance.auto_cleanup.*" string keys here, but it was never passed to
    // anything — the struct-based `WriteParams.auto_cleanup` below is what
    // actually configures cleanup, so the dead map has been removed.
    let tmpdir = TempStrDir::default();
    let dataset_uri = format!("{}/{}", tmpdir, "test_dataset");
    let data = lance_datagen::gen_batch()
        .with_seed(Seed::from(1))
        .col("id", array::step::<UInt32Type>())
        .into_reader_rows(RowCount::from(100), BatchCount::from(1));
    let write_params = WriteParams {
        mode: WriteMode::Create,
        auto_cleanup: Some(crate::dataset::AutoCleanupParams {
            interval: 1,
            older_than: chrono::TimeDelta::try_milliseconds(0).unwrap(),
        }),
        ..Default::default()
    };
    MockClock::set_system_time(std::time::Duration::from_secs(1));
    let dataset = Dataset::write(data, &dataset_uri, Some(write_params))
        .await
        .unwrap();
    assert_eq!(dataset.version().version, 1);

    // First merge-insert (no skip flag): commits version 2.
    MockClock::set_system_time(std::time::Duration::from_secs(2));
    let new_data = lance_datagen::gen_batch()
        .with_seed(Seed::from(2))
        .col("id", array::step::<UInt32Type>())
        .into_df_stream(RowCount::from(50), BatchCount::from(1));
    let (dataset2, _) = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute(new_data)
        .await
        .unwrap();
    assert_eq!(dataset2.version().version, 2);

    // Second merge-insert (no skip flag): commits version 3.
    MockClock::set_system_time(std::time::Duration::from_secs(3));
    let new_data_extra = lance_datagen::gen_batch()
        .with_seed(Seed::from(4))
        .col("id", array::step::<UInt32Type>())
        .into_df_stream(RowCount::from(10), BatchCount::from(1));
    let (dataset2_extra, _) =
        MergeInsertBuilder::try_new(dataset2.clone(), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .try_build()
            .unwrap()
            .execute(new_data_extra)
            .await
            .unwrap();
    assert_eq!(dataset2_extra.version().version, 3);

    // The unsuppressed merges above should have cleaned up version 1 while
    // leaving the immediately preceding version 2 intact.
    let ds_check1 = DatasetBuilder::from_uri(&dataset_uri).load().await.unwrap();
    assert!(
        ds_check1.checkout_version(1).await.is_err(),
        "Version 1 should have been cleaned up"
    );
    assert!(
        ds_check1.checkout_version(2).await.is_ok(),
        "Version 2 should still exist"
    );

    // Third merge-insert WITH skip_auto_cleanup: commits version 4 and must
    // leave versions 2 and 3 untouched.
    MockClock::set_system_time(std::time::Duration::from_secs(4));
    let new_data2 = lance_datagen::gen_batch()
        .with_seed(Seed::from(3))
        .col("id", array::step::<UInt32Type>())
        .into_df_stream(RowCount::from(30), BatchCount::from(1));
    let (dataset3, _) = MergeInsertBuilder::try_new(dataset2_extra, vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .skip_auto_cleanup(true)
        .try_build()
        .unwrap()
        .execute(new_data2)
        .await
        .unwrap();
    assert_eq!(dataset3.version().version, 4);
    let ds_check2 = DatasetBuilder::from_uri(&dataset_uri).load().await.unwrap();
    assert!(
        ds_check2.checkout_version(2).await.is_ok(),
        "Version 2 should still exist because skip_auto_cleanup was enabled"
    );
    assert!(
        ds_check2.checkout_version(3).await.is_ok(),
        "Version 3 should still exist"
    );
}
#[tokio::test]
async fn test_transaction_inserted_rows_filter_roundtrip() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false).with_metadata(
vec![(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
)]
.into_iter()
.collect(),
),
Field::new("value", DataType::UInt32, false),
]));
let initial = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![0, 1, 2])),
Arc::new(UInt32Array::from(vec![0, 0, 0])),
],
)
.unwrap();
let dataset = InsertBuilder::new("memory://")
.execute(vec![initial])
.await
.unwrap();
let dataset = Arc::new(dataset);
let new_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![1, 3])),
Arc::new(UInt32Array::from(vec![2, 2])),
],
)
.unwrap();
let stream = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(new_batch)]),
);
let UncommittedMergeInsert { transaction, .. } =
MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap()
.execute_uncommitted(Box::pin(stream) as SendableRecordBatchStream)
.await
.unwrap();
let committed = CommitBuilder::new(dataset.clone())
.execute(transaction)
.await
.unwrap();
let tx_path = committed.manifest().transaction_file.clone().unwrap();
let tx_read = read_transaction_file(dataset.object_store(), &dataset.base, &tx_path)
.await
.unwrap();
if let Operation::Update {
inserted_rows_filter,
..
} = &tx_read.operation
{
assert!(inserted_rows_filter.is_some());
let filter = inserted_rows_filter.as_ref().unwrap();
assert_eq!(filter.field_ids.len(), 1);
} else {
panic!("Expected Operation::Update");
}
}
#[tokio::test]
async fn test_inserted_rows_filter_bloom_conflict_detection_concurrent() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false).with_metadata(
vec![(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
)]
.into_iter()
.collect(),
),
Field::new("value", DataType::UInt32, false),
]));
let initial = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
],
)
.unwrap();
let dataset = InsertBuilder::new("memory://")
.execute(vec![initial])
.await
.unwrap();
let dataset = Arc::new(dataset);
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![2])),
Arc::new(UInt32Array::from(vec![1])),
],
)
.unwrap();
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![2])),
Arc::new(UInt32Array::from(vec![2])),
],
)
.unwrap();
let b2 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.conflict_retries(0)
.try_build()
.unwrap();
let s1 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch1.clone())]),
);
let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let result1 = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
assert!(result1.is_ok(), "First merge insert should succeed");
let s2 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch2.clone())]),
);
let result2 = b2.execute(Box::pin(s2) as SendableRecordBatchStream).await;
assert!(
matches!(result2, Err(crate::Error::TooMuchWriteContention { .. })),
"Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
result2
);
}
#[tokio::test]
async fn test_concurrent_insert_same_new_key() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false).with_metadata(
vec![(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
)]
.into_iter()
.collect(),
),
Field::new("value", DataType::UInt32, false),
]));
let initial = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
],
)
.unwrap();
let dataset = InsertBuilder::new("memory://")
.execute(vec![initial])
.await
.unwrap();
let dataset = Arc::new(dataset);
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![100])), Arc::new(UInt32Array::from(vec![1])),
],
)
.unwrap();
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![100])), Arc::new(UInt32Array::from(vec![2])),
],
)
.unwrap();
let b2 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.conflict_retries(0)
.try_build()
.unwrap();
let s1 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch1.clone())]),
);
let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let result1 = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
assert!(result1.is_ok(), "First merge insert should succeed");
let s2 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch2.clone())]),
);
let result2 = b2.execute(Box::pin(s2) as SendableRecordBatchStream).await;
assert!(
matches!(result2, Err(crate::Error::TooMuchWriteContention { .. })),
"Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
result2
);
}
#[test]
fn test_concurrent_insert_different_new_list_key() {
    // Two writers each insert a DIFFERENT list<utf8> key ("a","b" vs "c","d").
    // Their key-existence bloom filters must not intersect, and because bloom
    // filters have no false negatives, a miss is definitive (no conflict).
    let tags_field = Field::new(
        "tags",
        DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
        false,
    );
    let schema = Arc::new(Schema::new(vec![tags_field]));
    let mut builder = ListBuilder::new(StringBuilder::new());
    builder.append_value(["a", "b"].iter().copied().map(Some));
    let tags_array1 = builder.finish();
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(tags_array1)]).unwrap();
    let mut builder = ListBuilder::new(StringBuilder::new());
    builder.append_value(["c", "d"].iter().copied().map(Some));
    let tags_array2 = builder.finish();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(tags_array2)]).unwrap();
    let field_ids = vec![0_i32];
    let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
    let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
    let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
        .expect("second batch should produce key");
    builder1.insert(key1).unwrap();
    builder2.insert(key2).unwrap();
    let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
    let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
    let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
    // Fixed the previously ungrammatical/confusing failure messages below.
    assert!(
        !has_intersection,
        "Expected bloom filters not to intersect for different list(string) keys",
    );
    assert!(
        !might_be_fp,
        "Bloom filter miss should definitively indicate no conflict",
    );
}
#[test]
fn test_concurrent_insert_same_new_list_key() {
    // Two writers insert an IDENTICAL list<utf8> key ("a","b"); their bloom
    // filters must report an intersection, and the result is flagged as a
    // potential false positive.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "tags",
        DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
        false,
    )]));
    // Both batches hold the same single list value.
    let make_tags = || {
        let mut list_builder = ListBuilder::new(StringBuilder::new());
        list_builder.append_value(["a", "b"].iter().copied().map(Some));
        list_builder.finish()
    };
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(make_tags())]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(make_tags())]).unwrap();
    let key_columns = [String::from("tags")];
    let key1 = extract_key_value_from_batch(&batch1, 0, &key_columns)
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &key_columns)
        .expect("second batch should produce key");
    let mut filter_builder1 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    let mut filter_builder2 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    filter_builder1.insert(key1).unwrap();
    filter_builder2.insert(key2).unwrap();
    let (has_intersection, might_be_fp) = KeyExistenceFilter::from_bloom_filter(&filter_builder1)
        .intersects(&KeyExistenceFilter::from_bloom_filter(&filter_builder2))
        .unwrap();
    assert!(
        has_intersection,
        "Expected bloom filters to intersect for identical list(string) keys",
    );
    assert!(
        might_be_fp,
        "Bloom filter intersection should be treated as potential conflict",
    );
}
#[test]
fn test_concurrent_insert_same_new_nested_list_key() {
    // Identical list<list<utf8>> keys from two writers must collide in the
    // bloom filters, and the collision is a potential (false-positive) hit.
    let nested1 = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
    let nested2 = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "tags",
        nested1.data_type().clone(),
        false,
    )]));
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(nested1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(nested2)]).unwrap();
    let key_columns = [String::from("tags")];
    let key1 = extract_key_value_from_batch(&batch1, 0, &key_columns)
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &key_columns)
        .expect("second batch should produce key");
    let mut filter_builder1 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    let mut filter_builder2 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    filter_builder1.insert(key1).unwrap();
    filter_builder2.insert(key2).unwrap();
    let (has_intersection, might_be_fp) = KeyExistenceFilter::from_bloom_filter(&filter_builder1)
        .intersects(&KeyExistenceFilter::from_bloom_filter(&filter_builder2))
        .unwrap();
    assert!(
        has_intersection,
        "Expected bloom filters to intersect for identical nested list(list(string)) keys",
    );
    assert!(
        might_be_fp,
        "Bloom filter intersection should be treated as potential conflict",
    );
}
#[test]
fn test_concurrent_insert_different_new_struct_key() {
    // Two writers each insert a DIFFERENT struct<first,last> key. Their
    // key-existence bloom filters must not intersect, and because bloom
    // filters have no false negatives, a miss is definitive (no conflict).
    let user_field = Field::new(
        "user",
        DataType::Struct(
            vec![
                Field::new("first", DataType::Utf8, false),
                Field::new("last", DataType::Utf8, false),
            ]
            .into(),
        ),
        false,
    );
    let schema = Arc::new(Schema::new(vec![user_field]));
    let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
    let struct_array2 = make_struct_array_first_last_name(vec!["bob"], vec!["jones"]);
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
    let field_ids = vec![0_i32];
    let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
    let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
    let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
        .expect("second batch should produce key");
    builder1.insert(key1).unwrap();
    builder2.insert(key2).unwrap();
    let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
    let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
    let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
    // Fixed the previously ungrammatical/confusing failure messages below.
    assert!(
        !has_intersection,
        "Expected bloom filters not to intersect for different struct keys",
    );
    assert!(
        !might_be_fp,
        "Bloom filter miss should definitively indicate no conflict",
    );
}
#[test]
fn test_concurrent_insert_same_new_struct_key() {
    // Identical struct<first,last> keys from two writers must collide in the
    // bloom filters, and the collision is a potential (false-positive) hit.
    let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
    let struct_array2 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
    // Derive the schema's struct type directly from the helper's output.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "user",
        struct_array1.data_type().clone(),
        false,
    )]));
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
    let key_columns = [String::from("user")];
    let key1 = extract_key_value_from_batch(&batch1, 0, &key_columns)
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &key_columns)
        .expect("second batch should produce key");
    let mut filter_builder1 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    let mut filter_builder2 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    filter_builder1.insert(key1).unwrap();
    filter_builder2.insert(key2).unwrap();
    let (has_intersection, might_be_fp) = KeyExistenceFilter::from_bloom_filter(&filter_builder1)
        .intersects(&KeyExistenceFilter::from_bloom_filter(&filter_builder2))
        .unwrap();
    assert!(
        has_intersection,
        "Expected bloom filters to intersect for identical struct keys",
    );
    assert!(
        might_be_fp,
        "Bloom filter intersection should be treated as potential conflict",
    );
}
#[test]
fn test_concurrent_insert_same_new_nested_struct_key() {
    // Identical struct<address: struct<city, zip>> keys from two writers must
    // collide in the bloom filters as a potential (false-positive) hit.
    let outer1 = make_nested_struct_array_city_zip("seattle", 98101);
    let outer2 = make_nested_struct_array_city_zip("seattle", 98101);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "user",
        outer1.data_type().clone(),
        false,
    )]));
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(outer1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(outer2)]).unwrap();
    let key_columns = [String::from("user")];
    let key1 = extract_key_value_from_batch(&batch1, 0, &key_columns)
        .expect("first batch should produce key");
    let key2 = extract_key_value_from_batch(&batch2, 0, &key_columns)
        .expect("second batch should produce key");
    let mut filter_builder1 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    let mut filter_builder2 = KeyExistenceFilterBuilder::new(vec![0_i32]);
    filter_builder1.insert(key1).unwrap();
    filter_builder2.insert(key2).unwrap();
    let (has_intersection, might_be_fp) = KeyExistenceFilter::from_bloom_filter(&filter_builder1)
        .intersects(&KeyExistenceFilter::from_bloom_filter(&filter_builder2))
        .unwrap();
    assert!(
        has_intersection,
        "Expected bloom filters to intersect for identical nested struct keys",
    );
    assert!(
        might_be_fp,
        "Bloom filter intersection should be treated as potential conflict",
    );
}
#[tokio::test]
async fn test_merge_insert_struct_key_upsert() {
// End-to-end upsert keyed on a struct column. Target holds three users with
// value=1; the source updates (alice, smith) to value=10 and inserts a new
// (david, brown) with value=2. Stats and final table contents are verified.
let user_field = Field::new(
"user",
DataType::Struct(
vec![
Field::new("first", DataType::Utf8, false),
Field::new("last", DataType::Utf8, false),
]
.into(),
),
false,
);
let schema = Arc::new(Schema::new(vec![
user_field,
Field::new("value", DataType::UInt32, false),
]));
let user_array = make_struct_array_first_last_name(
vec!["alice", "bob", "carla"],
vec!["smith", "jones", "doe"],
);
let values = UInt32Array::from(vec![1, 1, 1]);
let initial_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(user_array), Arc::new(values)])
.unwrap();
let test_uri = "memory://test_merge_insert_struct_key.lance";
let dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()),
test_uri,
None,
)
.await
.unwrap();
let dataset = Arc::new(dataset);
// Source: one matching key (alice smith) and one new key (david brown).
let new_user_array =
make_struct_array_first_last_name(vec!["alice", "david"], vec!["smith", "brown"]);
let new_values = UInt32Array::from(vec![10, 2]);
let new_batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(new_user_array), Arc::new(new_values)],
)
.unwrap();
let reader = RecordBatchIterator::new([Ok(new_batch)], schema.clone());
let (merged_ds, stats) = MergeInsertBuilder::try_new(dataset, vec!["user".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap()
.execute(reader_to_stream(Box::new(reader)))
.await
.unwrap();
// Exactly one update (alice) and one insert (david), nothing deleted.
assert_eq!(stats.num_updated_rows, 1);
assert_eq!(stats.num_inserted_rows, 1);
assert_eq!(stats.num_deleted_rows, 0);
let result = merged_ds.scan().try_into_batch().await.unwrap();
let user_col = result
.column_by_name("user")
.unwrap()
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
let first = user_col
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let last = user_col
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let values = result
.column_by_name("value")
.unwrap()
.as_primitive::<UInt32Type>();
// Collect (first, last, value) tuples and sort so the comparison is
// independent of row ordering in the scan.
let mut rows = Vec::new();
for i in 0..result.num_rows() {
rows.push((
first.value(i).to_string(),
last.value(i).to_string(),
values.value(i),
));
}
rows.sort();
assert_eq!(
rows,
vec![
("alice".to_string(), "smith".to_string(), 10),
("bob".to_string(), "jones".to_string(), 1),
("carla".to_string(), "doe".to_string(), 1),
("david".to_string(), "brown".to_string(), 2),
],
);
}
fn make_struct_array_first_last_name(first: Vec<&str>, last: Vec<&str>) -> StructArray {
let first = StringArray::from(first);
let last = StringArray::from(last);
StructArray::from(vec![
(
Arc::new(Field::new("first", DataType::Utf8, false)),
Arc::new(first) as Arc<dyn Array>,
),
(
Arc::new(Field::new("last", DataType::Utf8, false)),
Arc::new(last) as Arc<dyn Array>,
),
])
}
/// Builds a one-row nested struct array of shape
/// struct<address: struct<city: utf8, zip: int32>> for nested-key tests.
fn make_nested_struct_array_city_zip(city: &str, zip: i32) -> StructArray {
    let city_entry = (
        Arc::new(Field::new("city", DataType::Utf8, false)),
        Arc::new(StringArray::from(vec![city])) as Arc<dyn Array>,
    );
    let zip_entry = (
        Arc::new(Field::new("zip", DataType::Int32, false)),
        Arc::new(Int32Array::from(vec![zip])) as Arc<dyn Array>,
    );
    let address = StructArray::from(vec![city_entry, zip_entry]);
    // The outer field's type is derived from the inner struct we just built.
    let address_field = Arc::new(Field::new("address", address.data_type().clone(), false));
    StructArray::from(vec![(address_field, Arc::new(address) as Arc<dyn Array>)])
}
/// Builds a single-element list<list<utf8>>: every slice in `inner_lists`
/// becomes one inner list, and all inner lists are wrapped in one outer
/// list entry.
fn make_nested_array(inner_lists: &[&[&str]]) -> ListArray {
    let mut inner_builder = ListBuilder::new(StringBuilder::new());
    inner_lists
        .iter()
        .for_each(|inner| inner_builder.append_value(inner.iter().map(|s| Some(*s))));
    let inner = inner_builder.finish();
    // One outer element covering all inner lists: offsets [0, len).
    let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, inner.len() as i32]));
    let item_field = Arc::new(Field::new(
        "item",
        inner.data_type().clone(),
        inner.nulls().is_some(),
    ));
    ListArray::new(item_field, offsets, Arc::new(inner), None)
}
#[tokio::test]
async fn test_merge_insert_conflict_with_update_without_filter() {
use crate::dataset::UpdateBuilder;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false).with_metadata(
vec![(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
)]
.into_iter()
.collect(),
),
Field::new("value", DataType::UInt32, false),
]));
let initial = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
],
)
.unwrap();
let dataset = InsertBuilder::new("memory://")
.execute(vec![initial])
.await
.unwrap();
let dataset = Arc::new(dataset);
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![100])),
Arc::new(UInt32Array::from(vec![1])),
],
)
.unwrap();
let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.conflict_retries(0)
.try_build()
.unwrap();
let update_result = UpdateBuilder::new(dataset.clone())
.update_where("id = 0")
.unwrap()
.set("value", "999")
.unwrap()
.build()
.unwrap()
.execute()
.await;
assert!(update_result.is_ok(), "Update should succeed");
let s1 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch1.clone())]),
);
let merge_result = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
assert!(
matches!(
merge_result,
Err(crate::Error::TooMuchWriteContention { .. })
),
"Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
merge_result
);
}
#[tokio::test]
async fn test_merge_insert_conflict_with_append() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false).with_metadata(
vec![(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
)]
.into_iter()
.collect(),
),
Field::new("value", DataType::UInt32, false),
]));
let initial = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
],
)
.unwrap();
let dataset = InsertBuilder::new("memory://")
.execute(vec![initial])
.await
.unwrap();
let dataset = Arc::new(dataset);
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![100])),
Arc::new(UInt32Array::from(vec![1])),
],
)
.unwrap();
let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.conflict_retries(0)
.try_build()
.unwrap();
let append_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![50])),
Arc::new(UInt32Array::from(vec![2])),
],
)
.unwrap();
let append_result = InsertBuilder::new(dataset.clone())
.with_params(&WriteParams {
mode: WriteMode::Append,
..Default::default()
})
.execute(vec![append_batch])
.await;
assert!(append_result.is_ok(), "Append should succeed");
let s1 = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(vec![Ok(batch1.clone())]),
);
let merge_result = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
assert!(
matches!(
merge_result,
Err(crate::Error::TooMuchWriteContention { .. })
),
"Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
merge_result
);
}
#[tokio::test]
async fn test_explain_plan() {
// Checks the explain_plan API in three modes: default (schema inferred),
// with an explicitly supplied source schema (must produce identical text),
// and verbose (must still match the same pattern).
let dataset = lance_datagen::gen_batch()
.col("id", lance_datagen::array::step::<Int32Type>())
.col("name", array::cycle_utf8_literals(&["a", "b", "c"]))
.into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(3))
.await
.unwrap();
let merge_insert_job =
MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
// Pattern uses `...` wildcards; only the node skeleton is pinned.
let expected_pattern = "\
MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep...
CoalescePartitionsExec...
HashJoinExec...
LanceRead...
StreamingTableExec: partition_sizes=1, projection=[id, name]";
assert_string_matches(&plan, expected_pattern).unwrap();
// Passing the source schema explicitly must not change the plan text.
let source_schema = arrow_schema::Schema::from(dataset.schema());
let explicit_plan = merge_insert_job
.explain_plan(Some(&source_schema), false)
.await
.unwrap();
assert_eq!(plan, explicit_plan);
let verbose_plan = merge_insert_job.explain_plan(None, true).await.unwrap();
assert!(verbose_plan.contains("MergeInsert"));
assert_string_matches(&verbose_plan, expected_pattern).unwrap();
}
#[tokio::test]
async fn test_explain_plan_full_schema_delete_by_source_with_fsl() {
    // With when_not_matched_by_source=Delete the planner must use a Full join,
    // and the target-side LanceRead should project only the join key `id`,
    // pruning the wide FixedSizeList `vec` payload from the build side.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "vec",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            true,
        ),
    ]));
    let dataset_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(
                FixedSizeListArray::try_new_from_values(
                    Float32Array::from(vec![
                        1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
                    ]),
                    4,
                )
                .unwrap(),
            ),
        ],
    )
    .unwrap();
    let dataset = Dataset::write(
        Box::new(RecordBatchIterator::new(
            [Ok(dataset_batch)],
            schema.clone(),
        )),
        "memory://test_explain_plan_full_schema_delete_by_source_with_fsl",
        None,
    )
    .await
    .unwrap();
    let merge_insert_job =
        MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
            .use_index(false)
            .try_build()
            .unwrap();
    let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
    assert!(plan.contains("HashJoinExec"));
    assert!(plan.contains("join_type=Full"));
    assert!(plan.contains("projection=[_rowid"));
    assert!(
        plan.contains("LanceRead: uri=") && plan.contains("projection=[id]"),
        "target-side scan should prune the FSL payload from the join build side: {plan}"
    );
    // BUGFIX: the previous negative assertion searched for a literal uri
    // without the `memory://` scheme, so it could never match and was
    // vacuously true. Instead, inspect the actual LanceRead line and require
    // that `vec` does not appear in its projection.
    let lance_read_line = plan
        .lines()
        .find(|line| line.contains("LanceRead"))
        .expect("plan should contain a LanceRead node");
    assert!(
        !lance_read_line.contains("vec"),
        "target-side scan should not include the FSL payload in the join build side: {plan}"
    );
}
#[tokio::test]
async fn test_merge_insert_full_schema_delete_by_source_with_fsl() {
    // End-to-end merge with when_not_matched_by_source=Delete on a table with a
    // FixedSizeList column: id=2 is updated, id=4 is inserted, and ids 1 and 3
    // (absent from the source) are deleted.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "vec",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            true,
        ),
    ]));
    // Seed rows: ids 1..=3, each with a 4-float vector.
    let dataset_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(
                FixedSizeListArray::try_new_from_values(
                    Float32Array::from(vec![
                        1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
                    ]),
                    4,
                )
                .unwrap(),
            ),
        ],
    )
    .unwrap();
    let dataset = Dataset::write(
        Box::new(RecordBatchIterator::new(
            [Ok(dataset_batch)],
            schema.clone(),
        )),
        "memory://test_merge_insert_full_schema_delete_by_source_with_fsl",
        None,
    )
    .await
    .unwrap();
    // Source: id=2 (matches -> update), id=4 (new -> insert).
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![2, 4])),
            Arc::new(
                FixedSizeListArray::try_new_from_values(
                    Float32Array::from(vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]),
                    4,
                )
                .unwrap(),
            ),
        ],
    )
    .unwrap();
    let (merged_dataset, stats) =
        MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
            .try_build()
            .unwrap()
            .execute_reader(Box::new(RecordBatchIterator::new(
                [Ok(source_batch)],
                schema.clone(),
            )))
            .await
            .unwrap();
    // ids 1 and 3 deleted, id 2 updated, id 4 inserted.
    assert_eq!(stats.num_deleted_rows, 2);
    assert_eq!(stats.num_updated_rows, 1);
    assert_eq!(stats.num_inserted_rows, 1);
    // The surviving rows are exactly the source rows, with their new vectors.
    let merged = merged_dataset.scan().try_into_batch().await.unwrap();
    let ids = merged["id"].as_primitive::<Int32Type>().values().to_vec();
    assert_eq!(ids, vec![2, 4]);
    let vecs = merged["vec"].as_fixed_size_list();
    let actual = vecs
        .values()
        .as_primitive::<Float32Type>()
        .values()
        .to_vec();
    assert_eq!(actual, vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]);
}
#[tokio::test]
async fn test_analyze_plan() {
    // analyze_plan must execute the merge, report runtime metrics
    // (bytes_written, num_files_written, elapsed_compute, row counts) and then
    // NOT commit a new dataset version.
    let mut dataset = lance_datagen::gen_batch()
        .col("id", lance_datagen::array::step::<Int32Type>())
        .col("name", array::cycle_utf8_literals(&["a", "b", "c"]))
        .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(3))
        .await
        .unwrap();
    let original_version = dataset.version().version;
    let merge_insert_job =
        MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .try_build()
            .unwrap();
    let schema = Arc::new(arrow_schema::Schema::from(dataset.schema()));
    // id=1 matches the seed data (update); id=4 does not (insert).
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 4])),
            Arc::new(StringArray::from(vec!["updated_a", "d"])),
        ],
    )
    .unwrap();
    let source_stream = RecordBatchStreamAdapter::new(
        schema,
        futures::stream::once(async { Ok(source_batch) }).boxed(),
    );
    // Bracket the output so the wildcard pattern below can anchor on both ends.
    let mut analysis = String::from("[");
    analysis.push_str(
        &merge_insert_job
            .analyze_plan(Box::pin(source_stream))
            .await
            .unwrap(),
    );
    // Appending a single char: no need to allocate a String for it.
    analysis.push(']');
    assert!(analysis.contains("MergeInsert"));
    assert!(analysis.contains("metrics"));
    assert!(analysis.contains("bytes_written"));
    assert!(analysis.contains("num_files_written"));
    // analyze_plan ran the plan, but the dataset version must be unchanged.
    dataset.checkout_latest().await.unwrap();
    assert_eq!(
        dataset.version().version,
        original_version,
        "analyze_plan should not create a new dataset version"
    );
    let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_skipped_duplicates=0, num_updated_rows=1]
...
StreamingTableExec: partition_sizes=1, projection=[id, name], metrics=[]...]";
    assert_string_matches(&analysis, expected_pattern).unwrap();
    assert!(analysis.contains("elapsed_compute"));
}
#[tokio::test]
async fn test_merge_insert_with_action_column() {
let initial_data = RecordBatch::try_new(
Arc::new(arrow_schema::Schema::new(vec![
arrow_schema::Field::new("id", arrow_schema::DataType::Int32, false),
arrow_schema::Field::new("action", arrow_schema::DataType::Utf8, true),
arrow_schema::Field::new("value", arrow_schema::DataType::Int32, true),
])),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["create", "update", "delete"])),
Arc::new(Int32Array::from(vec![10, 20, 30])),
],
)
.unwrap();
let tempdir = TempStrDir::default();
let dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
&tempdir,
None,
)
.await
.unwrap();
let new_data = RecordBatch::try_new(
Arc::new(arrow_schema::Schema::new(vec![
arrow_schema::Field::new("id", arrow_schema::DataType::Int32, false),
arrow_schema::Field::new("action", arrow_schema::DataType::Utf8, true),
arrow_schema::Field::new("value", arrow_schema::DataType::Int32, true),
])),
vec![
Arc::new(Int32Array::from(vec![2, 4])),
Arc::new(StringArray::from(vec!["modify", "insert"])),
Arc::new(Int32Array::from(vec![25, 40])),
],
)
.unwrap();
let merge_insert_job =
MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let new_reader = Box::new(RecordBatchIterator::new(
[Ok(new_data.clone())],
new_data.schema(),
));
let new_stream = reader_to_stream(new_reader);
let (merged_dataset, _) = merge_insert_job.execute(new_stream).await.unwrap();
let result_batches = merged_dataset
.scan()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let result_batch = concat_batches(&result_batches[0].schema(), &result_batches).unwrap();
assert_eq!(result_batch.num_rows(), 4);
let id_col = result_batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let action_col = result_batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let value_col = result_batch
.column(2)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for i in 0..result_batch.num_rows() {
match id_col.value(i) {
1 => {
assert_eq!(action_col.value(i), "create");
assert_eq!(value_col.value(i), 10);
}
2 => {
assert_eq!(action_col.value(i), "modify"); assert_eq!(value_col.value(i), 25); }
3 => {
assert_eq!(action_col.value(i), "delete");
assert_eq!(value_col.value(i), 30);
}
4 => {
assert_eq!(action_col.value(i), "insert"); assert_eq!(value_col.value(i), 40); }
_ => panic!("Unexpected id: {}", id_col.value(i)),
}
}
}
#[tokio::test]
#[rstest::rstest]
async fn test_duplicate_rowid_detection(
#[values(false, true)] is_full_schema: bool,
#[values(true, false)] enable_stable_row_ids: bool,
#[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1, LanceFileVersion::V2_2)]
data_storage_version: LanceFileVersion,
) {
let test_uri = "memory://test_duplicate_rowid_multi_fragment.lance";
let dataset = lance_datagen::gen_batch()
.col("key", array::step_custom::<UInt32Type>(1, 1))
.col("value", array::step_custom::<UInt32Type>(10, 10))
.into_dataset_with_params(
test_uri,
FragmentCount(3),
FragmentRowCount(4),
Some(WriteParams {
max_rows_per_file: 4,
enable_stable_row_ids,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
assert_eq!(dataset.get_fragments().len(), 3, "Should have 3 fragments");
let schema = Arc::new(Schema::new(vec![
Field::new("key", DataType::UInt32, is_full_schema),
Field::new("value", DataType::UInt32, is_full_schema),
]));
let source_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![2, 2, 6, 6, 10, 10, 15])),
Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600, 700])),
],
)
.unwrap();
let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.try_build()
.unwrap();
let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
let stream = reader_to_stream(reader);
let result = job.execute(stream).await;
assert!(
result.is_err(),
"Expected merge insert to fail due to duplicate rows on key column."
);
assert!(
matches!(&result, &Err(Error::InvalidInput { ref source, .. }) if source.to_string().contains("Ambiguous merge insert") && source.to_string().contains("multiple source rows")),
"Expected error to be InvalidInput with message about ambiguous merge insert and multiple source rows, got: {:?}",
result
);
}
#[tokio::test]
#[rstest::rstest]
async fn test_source_dedupe_behavior_first_seen(
    #[values(false, true)] is_full_schema: bool,
    #[values(true, false)] enable_stable_row_ids: bool,
    #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1, LanceFileVersion::V2_2)]
    data_storage_version: LanceFileVersion,
) {
    // With SourceDedupeBehavior::FirstSeen, duplicate source keys are collapsed
    // to their first occurrence instead of failing the merge; the extras are
    // reported via stats.num_skipped_duplicates.
    let test_uri = format!(
        "memory://test_dedupe_first_seen_{}_{}.lance",
        is_full_schema, enable_stable_row_ids
    );
    // Single 4-row fragment: keys 1..=4, values 10..=40.
    let dataset = lance_datagen::gen_batch()
        .col("key", array::step_custom::<UInt32Type>(1, 1))
        .col("value", array::step_custom::<UInt32Type>(10, 10))
        .into_dataset_with_params(
            &test_uri,
            FragmentCount(1),
            FragmentRowCount(4),
            Some(WriteParams {
                max_rows_per_file: 4,
                enable_stable_row_ids,
                data_storage_version: Some(data_storage_version),
                ..Default::default()
            }),
        )
        .await
        .unwrap();
    // Zip the two scanned columns into (key, value) pairs to sanity-check the
    // seed data: the first column seeds the tuples, the second fills in values.
    let initial_data: Vec<(u32, u32)> = dataset
        .scan()
        .try_into_batch()
        .await
        .unwrap()
        .columns()
        .iter()
        .map(|c| c.as_primitive::<UInt32Type>().values().to_vec())
        .collect::<Vec<_>>()
        .into_iter()
        .fold(Vec::new(), |mut acc, vals| {
            if acc.is_empty() {
                acc = vals.into_iter().map(|v| (v, 0)).collect();
            } else {
                for (i, v) in vals.into_iter().enumerate() {
                    acc[i].1 = v;
                }
            }
            acc
        });
    assert_eq!(
        initial_data,
        vec![(1, 10), (2, 20), (3, 30), (4, 40)],
        "Initial data should be correct"
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::UInt32, is_full_schema),
        Field::new("value", DataType::UInt32, is_full_schema),
    ]));
    // Source: key=2 appears three times, key=3 twice, key=5 is new.
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![2, 2, 2, 3, 3, 5])),
            Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600])),
        ],
    )
    .unwrap();
    let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
        .try_build()
        .unwrap();
    let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
    let stream = reader_to_stream(reader);
    let (dataset, stats) = job.execute(stream).await.unwrap();
    assert_eq!(
        stats.num_skipped_duplicates, 3,
        "Should have skipped 3 duplicate rows (2 extra for key=2, 1 extra for key=3)"
    );
    assert_eq!(
        stats.num_updated_rows, 2,
        "Should have updated 2 rows (key=2 and key=3)"
    );
    assert_eq!(
        stats.num_inserted_rows, 1,
        "Should have inserted 1 row (key=5)"
    );
    // Verify the final table keyed by "key" (row order is not guaranteed).
    let result_batch = dataset.scan().try_into_batch().await.unwrap();
    let keys = result_batch.column(0).as_primitive::<UInt32Type>();
    let values = result_batch.column(1).as_primitive::<UInt32Type>();
    let result_data: std::collections::HashMap<u32, u32> = keys
        .values()
        .iter()
        .zip(values.values().iter())
        .map(|(&k, &v)| (k, v))
        .collect();
    assert_eq!(result_data.len(), 5, "Should have 5 rows total");
    assert_eq!(
        result_data.get(&1),
        Some(&10),
        "key=1 should be unchanged (original value)"
    );
    assert_eq!(
        result_data.get(&2),
        Some(&100),
        "key=2 should have first seen value (100, not 200 or 300)"
    );
    assert_eq!(
        result_data.get(&3),
        Some(&400),
        "key=3 should have first seen value (400, not 500)"
    );
    assert_eq!(
        result_data.get(&4),
        Some(&40),
        "key=4 should be unchanged (original value)"
    );
    assert_eq!(
        result_data.get(&5),
        Some(&600),
        "key=5 should be inserted with value 600"
    );
}
#[tokio::test]
async fn test_merge_insert_use_index() {
    // use_index(false) must force the join-based plan even when a scalar index
    // exists on the join key; use_index(true) routes through the indexed path,
    // which does not support explain_plan.
    let data = lance_datagen::gen_batch()
        .col("id", lance_datagen::array::step::<Int32Type>())
        .col("value", array::step::<UInt32Type>());
    let data = data.into_reader_rows(RowCount::from(100), BatchCount::from(1));
    let schema = data.schema();
    let mut ds = Dataset::write(data, "memory://", None).await.unwrap();
    // Index the join key so the indexed merge path is available.
    let index_params = ScalarIndexParams::default();
    ds.create_index(&["id"], IndexType::Scalar, None, &index_params, false)
        .await
        .unwrap();
    // ids 1 and 2 match the 0..100 seed range (update); 101 is new (insert).
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 101])),
            Arc::new(UInt32Array::from(vec![999, 999, 999])),
        ],
    )
    .unwrap();
    let merge_job_no_index =
        MergeInsertBuilder::try_new(Arc::new(ds.clone()), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .use_index(false)
            .try_build()
            .unwrap();
    // The join-based plan can be explained.
    let plan = merge_job_no_index.explain_plan(None, false).await;
    assert!(
        plan.is_ok(),
        "explain_plan should succeed with use_index=false"
    );
    let plan_str = plan.unwrap();
    assert!(plan_str.contains("MergeInsert"));
    assert!(plan_str.contains("HashJoinExec"));
    let merge_job_with_index =
        MergeInsertBuilder::try_new(Arc::new(ds.clone()), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .use_index(true)
            .try_build()
            .unwrap();
    // The indexed path refuses explain_plan with a NotSupported error.
    let plan_result = merge_job_with_index.explain_plan(None, false).await;
    assert!(
        plan_result.is_err(),
        "explain_plan should fail with use_index=true when index exists"
    );
    match plan_result {
        Err(Error::NotSupported { source, .. }) => {
            assert!(source.to_string().contains("does not support explain_plan"));
        }
        _ => panic!("Expected NotSupported error"),
    }
    // The non-indexed job still executes correctly end to end.
    let source = Box::new(RecordBatchIterator::new(
        vec![Ok(source_batch.clone())],
        schema.clone(),
    ));
    let (result_ds, stats) = merge_job_no_index.execute_reader(source).await.unwrap();
    assert_eq!(stats.num_updated_rows, 2);
    assert_eq!(stats.num_inserted_rows, 1);
    // All three source rows (2 updated + 1 inserted) carry value 999.
    let updated_count = result_ds
        .count_rows(Some("value = 999".to_string()))
        .await
        .unwrap();
    assert_eq!(updated_count, 3);
}
#[tokio::test]
async fn test_full_schema_upsert_fragment_bitmap() {
    // Full-schema upsert on a dataset carrying scalar + vector indexes: both
    // indexes initially cover fragments {0, 1}; after upserting two rows the
    // dataset ends with three fragments (rewritten data lands in a new one).
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::UInt32, true),
        Field::new("value", DataType::UInt32, true),
        Field::new(
            "vec",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            true,
        ),
    ]));
    // Two fragments of 3 rows each, stable row ids enabled.
    let mut dataset = lance_datagen::gen_batch()
        .col("key", array::step_custom::<UInt32Type>(1, 1))
        .col("value", array::step_custom::<UInt32Type>(10, 10))
        .col(
            "vec",
            array::cycle_vec(
                array::cycle::<Float32Type>(vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
                ]),
                Dimension::from(4),
            ),
        )
        .into_ram_dataset_with_params(
            FragmentCount::from(2),
            FragmentRowCount::from(3),
            Some(WriteParams {
                max_rows_per_file: 3,
                enable_stable_row_ids: true,
                ..Default::default()
            }),
        )
        .await
        .unwrap();
    // Scalar index on "value" and IVF-flat vector index on "vec".
    let scalar_params = ScalarIndexParams::default();
    dataset
        .create_index(
            &["value"],
            IndexType::Scalar,
            Some("value_idx".to_string()),
            &scalar_params,
            true,
        )
        .await
        .unwrap();
    let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2);
    dataset
        .create_index(
            &["vec"],
            IndexType::Vector,
            Some("vec_idx".to_string()),
            &vector_params,
            true,
        )
        .await
        .unwrap();
    // Sanity: both freshly built indexes cover both fragments.
    let indices = dataset.load_indices().await.unwrap();
    let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap();
    let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap();
    assert_eq!(
        value_index
            .fragment_bitmap
            .as_ref()
            .unwrap()
            .iter()
            .collect::<Vec<_>>(),
        vec![0, 1]
    );
    assert_eq!(
        vec_index
            .fragment_bitmap
            .as_ref()
            .unwrap()
            .iter()
            .collect::<Vec<_>>(),
        vec![0, 1]
    );
    // Upsert keys 2 and 5 — one row from each original fragment — with
    // full-schema replacement rows.
    let upsert_keys = UInt32Array::from(vec![2, 5]);
    let upsert_values = UInt32Array::from(vec![200, 500]);
    let upsert_vecs = FixedSizeListArray::try_new_from_values(
        Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]),
        4,
    )
    .unwrap();
    let upsert_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(upsert_keys),
            Arc::new(upsert_values),
            Arc::new(upsert_vecs),
        ],
    )
    .unwrap();
    let upsert_stream = RecordBatchStreamAdapter::new(
        schema.clone(),
        futures::stream::once(async { Ok(upsert_batch) }).boxed(),
    );
    let (updated_dataset, _stats) =
        MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::DoNothing)
            .when_not_matched_by_source(WhenNotMatchedBySource::Keep)
            .try_build()
            .unwrap()
            .execute(Box::pin(upsert_stream))
            .await
            .unwrap();
    // 2 original fragments + 1 new fragment containing the rewritten rows.
    let fragments = updated_dataset.get_fragments();
    assert_eq!(fragments.len(), 3);
}
#[tokio::test]
async fn test_sub_schema_upsert_fragment_bitmap() {
    // Sub-schema (partial-column) upsert: only the "key" and "vec" columns are
    // supplied. The fragment count must stay at 2, and afterwards only the
    // scalar index survives — the assertions below show vec_idx is gone,
    // presumably invalidated because its column was rewritten (NOTE(review):
    // confirm against the index-maintenance logic).
    let mut dataset = lance_datagen::gen_batch()
        .col("key", array::step_custom::<UInt32Type>(1, 1))
        .col("value", array::step_custom::<UInt32Type>(10, 10))
        .col(
            "vec",
            array::cycle_vec(
                array::cycle::<Float32Type>(vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
                ]),
                Dimension::from(4),
            ),
        )
        .into_ram_dataset_with_params(
            FragmentCount::from(2),
            FragmentRowCount::from(3),
            Some(WriteParams {
                max_rows_per_file: 3,
                enable_stable_row_ids: true,
                ..Default::default()
            }),
        )
        .await
        .unwrap();
    // Scalar index on "value" and IVF-flat vector index on "vec".
    let scalar_params = ScalarIndexParams::default();
    dataset
        .create_index(
            &["value"],
            IndexType::Scalar,
            Some("value_idx".to_string()),
            &scalar_params,
            true,
        )
        .await
        .unwrap();
    let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2);
    dataset
        .create_index(
            &["vec"],
            IndexType::Vector,
            Some("vec_idx".to_string()),
            &vector_params,
            true,
        )
        .await
        .unwrap();
    // Sanity: both indexes cover fragments {0, 1} before the merge.
    let indices = dataset.load_indices().await.unwrap();
    let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap();
    let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap();
    assert_eq!(
        value_index
            .fragment_bitmap
            .as_ref()
            .unwrap()
            .iter()
            .collect::<Vec<_>>(),
        vec![0, 1]
    );
    assert_eq!(
        vec_index
            .fragment_bitmap
            .as_ref()
            .unwrap()
            .iter()
            .collect::<Vec<_>>(),
        vec![0, 1]
    );
    // Source schema omits "value": only key + vec are updated.
    let sub_schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::UInt32, true),
        Field::new(
            "vec",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            true,
        ),
    ]));
    // Upsert keys 2 and 5 — one row from each original fragment.
    let upsert_keys = UInt32Array::from(vec![2, 5]);
    let upsert_vecs = FixedSizeListArray::try_new_from_values(
        Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]),
        4,
    )
    .unwrap();
    let upsert_batch = RecordBatch::try_new(
        sub_schema.clone(),
        vec![Arc::new(upsert_keys), Arc::new(upsert_vecs)],
    )
    .unwrap();
    let upsert_stream = RecordBatchStreamAdapter::new(
        sub_schema.clone(),
        futures::stream::once(async { Ok(upsert_batch) }).boxed(),
    );
    let (updated_dataset, _stats) =
        MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::DoNothing)
            .when_not_matched_by_source(WhenNotMatchedBySource::Keep)
            .try_build()
            .unwrap()
            .execute(Box::pin(upsert_stream))
            .await
            .unwrap();
    // Partial-column update: no new fragment is added.
    let fragments = updated_dataset.get_fragments();
    assert_eq!(fragments.len(), 2);
    // Only value_idx remains, still covering both fragments ("value" untouched).
    let updated_indices = updated_dataset.load_indices().await.unwrap();
    assert_eq!(updated_indices.len(), 1);
    let updated_value_index = updated_indices
        .iter()
        .find(|idx| idx.name == "value_idx")
        .unwrap();
    let value_bitmap = updated_value_index.fragment_bitmap.as_ref().unwrap();
    assert_eq!(value_bitmap.len(), 2);
    assert!(value_bitmap.contains(0));
    assert!(value_bitmap.contains(1));
}
#[tokio::test]
async fn test_when_matched_fail() {
    // WhenMatched::Fail: the merge must abort if any source key matches an
    // existing row, and succeed (insert-only) when no key matches.
    let dataset = create_test_dataset("memory://test_fail", LanceFileVersion::V2_0, true).await;
    // Keys 1 and 2 collide with seed rows; 10 and 11 are new.
    let new_data = RecordBatch::try_new(
        create_test_schema(),
        vec![
            Arc::new(UInt32Array::from(vec![1, 2, 10, 11])),
            Arc::new(UInt32Array::from(vec![100, 200, 1000, 1100])),
            Arc::new(StringArray::from(vec!["X", "Y", "Z", "W"])),
        ],
    )
    .unwrap();
    let reader = Box::new(RecordBatchIterator::new(
        [Ok(new_data.clone())],
        new_data.schema(),
    ));
    let new_stream = reader_to_stream(reader);
    let result = MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
        .unwrap()
        .when_matched(WhenMatched::Fail)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute(new_stream)
        .await;
    // Any match at all must surface as an error, not a partial commit.
    match result {
        Ok((_dataset, stats)) => {
            panic!(
                "Expected merge insert to fail, but it succeeded. Stats: {:?}",
                stats
            );
        }
        Err(e) => {
            let error_msg = e.to_string();
            assert!(error_msg.contains("Merge insert failed"));
            assert!(error_msg.contains("found matching row"));
        }
    }
    // Second attempt with only unmatched keys (10..=12): pure insert succeeds.
    let new_data = RecordBatch::try_new(
        create_test_schema(),
        vec![
            Arc::new(UInt32Array::from(vec![10, 11, 12])),
            Arc::new(UInt32Array::from(vec![1000, 1100, 1200])),
            Arc::new(StringArray::from(vec!["X", "Y", "Z"])),
        ],
    )
    .unwrap();
    let reader = Box::new(RecordBatchIterator::new(
        [Ok(new_data.clone())],
        new_data.schema(),
    ));
    let new_stream = reader_to_stream(reader);
    let (updated_dataset, stats) =
        MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
            .unwrap()
            .when_matched(WhenMatched::Fail)
            .when_not_matched(WhenNotMatched::InsertAll)
            .try_build()
            .unwrap()
            .execute(new_stream)
            .await
            .unwrap();
    assert_eq!(stats.num_inserted_rows, 3);
    assert_eq!(stats.num_updated_rows, 0);
    assert_eq!(stats.num_deleted_rows, 0);
    // The three inserted rows are the only keys >= 10.
    let count = updated_dataset
        .count_rows(Some("key >= 10".to_string()))
        .await
        .unwrap();
    assert_eq!(count, 3);
}
#[tokio::test]
async fn test_merge_insert_permissive_nullability() {
    // Regression test (issue 4654): a source whose schema marks fields as
    // nullable may be merged into a dataset with non-nullable fields as long as
    // the data contains no actual nulls — matching insert/append behavior.
    let non_nullable_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("value", DataType::Int64, false),
    ]));
    let initial_data = RecordBatch::try_new(
        non_nullable_schema.clone(),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])),
            Arc::new(Int64Array::from(vec![100, 200, 300])),
        ],
    )
    .unwrap();
    let test_uri = "memory://test_nullable_issue_4654";
    let dataset = Dataset::write(
        RecordBatchIterator::new(vec![Ok(initial_data)], non_nullable_schema.clone()),
        test_uri,
        None,
    )
    .await
    .unwrap();
    // Same fields, but declared nullable — no actual null values present.
    let nullable_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, true),
        Field::new("value", DataType::Int64, true),
    ]));
    // id=2 matches (update to 999); ids 4 and 5 are new (insert).
    let new_data = RecordBatch::try_new(
        nullable_schema.clone(),
        vec![
            Arc::new(Int64Array::from(vec![2, 4, 5])),
            Arc::new(Int64Array::from(vec![999, 400, 500])),
        ],
    )
    .unwrap();
    let merge_result = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute_reader(Box::new(RecordBatchIterator::new(
            vec![Ok(new_data.clone())],
            nullable_schema.clone(),
        )))
        .await;
    assert!(
        merge_result.is_ok(),
        "merge_insert() should succeed with nullable fields but no actual nulls. \
        This is the same behavior as insert/append. Error: {:?}",
        merge_result.err()
    );
    let (merged_dataset, stats) = merge_result.unwrap();
    assert_eq!(stats.num_updated_rows, 1, "Should update 1 row (id=2)");
    assert_eq!(
        stats.num_inserted_rows, 2,
        "Should insert 2 new rows (id=4,5)"
    );
    let count = merged_dataset.count_rows(None).await.unwrap();
    assert_eq!(count, 5, "Should have 5 total rows");
    // Spot-check the updated row's new value.
    let result = merged_dataset
        .scan()
        .filter("id = 2")
        .unwrap()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let batch = concat_batches(&result[0].schema(), &result).unwrap();
    assert_eq!(batch.num_rows(), 1);
    let value_array = batch
        .column(1)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(
        value_array.value(0),
        999,
        "Value for id=2 should be updated to 999"
    );
}
#[tokio::test]
async fn test_merge_insert_subschema_invalid_type_error() {
    // Regression test (issue 3634): a source with a subset of the target's
    // columns but a mismatched type for one of them must fail with a
    // SchemaMismatch that names the type conflict — and must NOT complain about
    // the legitimately omitted "extra" column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Float64, true),
        Field::new("extra", DataType::Utf8, true),
    ]));
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(Float64Array::from(vec![1.1, 2.2, 3.3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )
    .unwrap();
    let test_uri = "memory://test_issue_3634";
    let dataset = Dataset::write(
        RecordBatchIterator::new(vec![Ok(initial_data)], schema),
        test_uri,
        None,
    )
    .await
    .unwrap();
    // Subschema drops "extra" (allowed) but declares "value" as Int32 instead
    // of the target's Float64 (not allowed).
    let subschema_with_wrong_type = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Int32, true),
    ]));
    let new_data = RecordBatch::try_new(
        subschema_with_wrong_type.clone(),
        vec![
            Arc::new(Int32Array::from(vec![2, 4])),
            Arc::new(Int32Array::from(vec![22, 44])),
        ],
    )
    .unwrap();
    let merge_result = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap()
        .execute_reader(Box::new(RecordBatchIterator::new(
            vec![Ok(new_data)],
            subschema_with_wrong_type,
        )))
        .await;
    let err = merge_result.expect_err("Merge insert should have failed but it succeeded.");
    assert!(
        matches!(err, lance_core::Error::SchemaMismatch { .. }),
        "Expected a SchemaMismatch error, but got a different error type: {:?}",
        err
    );
    // The message must identify the conflicting column with both types.
    let error_message = err.to_string();
    assert!(
        error_message.contains("`value` should have type double but type was int32"),
        "Error message should specify the expected (double) and actual (int32) types for 'value', but was: {}",
        error_message
    );
    // Omitting columns is legal for a subschema merge — no "missing=" noise.
    assert!(
        !error_message.contains("missing="),
        "Error message should NOT complain about missing fields for a subschema check, but was: {}",
        error_message
    );
}
#[tokio::test]
async fn test_merge_insert_mixed_case_key() {
    // A mixed-case join key ("userId") must survive planning unmangled and the
    // matching row must actually be updated.
    let schema = Arc::new(Schema::new(vec![
        Field::new("userId", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, true),
    ]));
    let seed_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![1, 2, 3])),
            Arc::new(UInt32Array::from(vec![10, 20, 30])),
        ],
    )
    .unwrap();
    let ds = Dataset::write(
        RecordBatchIterator::new(vec![Ok(seed_batch)], schema.clone()),
        "memory://test_mixed_case.lance",
        None,
    )
    .await
    .unwrap();
    // userId=2 matches and is updated to 400... no: 200; userId=4 ends up in
    // the table as well (the row-count assert below shows 4 rows after the
    // merge, i.e. the builder's default handles the unmatched source row).
    let update_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![2, 4])),
            Arc::new(UInt32Array::from(vec![200, 400])),
        ],
    )
    .unwrap();
    let merge = MergeInsertBuilder::try_new(Arc::new(ds), vec!["userId".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .try_build()
        .unwrap();
    let source_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(update_batch)],
        schema.clone(),
    )));
    let (merged_ds, _merge_stats) = merge.execute(source_stream).await.unwrap();
    let batches = merged_ds
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&schema, &batches).unwrap();
    assert_eq!(merged.num_rows(), 4);
    let user_ids = merged
        .column(0)
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap();
    let values = merged
        .column(1)
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap();
    for row in 0..merged.num_rows() {
        if user_ids.value(row) == 2 {
            assert_eq!(
                values.value(row),
                200,
                "userId=2 should have been updated to value=200"
            );
        }
    }
}
#[tokio::test]
async fn test_merge_insert_reordered_columns() {
    // Regression test (issue 5323): a source with the same column set as the
    // target but in a different order must still qualify for the fast
    // create_plan path and merge correctly.
    use arrow_array::record_batch;
    let initial_data = record_batch!(
        ("id", Int32, [1, 2, 3]),
        ("value", Float64, [1.1, 2.2, 3.3]),
        ("extra", Utf8, ["a", "b", "c"])
    )
    .unwrap();
    let dataset = Dataset::write(
        RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
        "memory://test_issue_5323",
        None,
    )
    .await
    .unwrap();
    // Same columns as the target, ordered extra/id/value instead of
    // id/value/extra; id=2 matches, id=4 is new.
    let new_data = record_batch!(
        ("extra", Utf8, ["x", "y"]),
        ("id", Int32, [2, 4]),
        ("value", Float64, [22.2, 44.4])
    )
    .unwrap();
    let job = MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap();
    // Column order must not disqualify the fast path.
    assert!(
        job.can_use_create_plan(&new_data.schema()).await.unwrap(),
        "Reordered schema should be able to use fast path"
    );
    let (merged_dataset, _) =
        MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
            .unwrap()
            .when_matched(WhenMatched::UpdateAll)
            .when_not_matched(WhenNotMatched::InsertAll)
            .try_build()
            .unwrap()
            .execute_reader(Box::new(RecordBatchIterator::new(
                vec![Ok(new_data.clone())],
                new_data.schema(),
            )))
            .await
            .unwrap();
    // Sort by id so the comparison against the expected batch is deterministic.
    let result = merged_dataset
        .scan()
        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
            "id".to_string(),
        )]))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    // Output uses the target's column order; id=2 carries the source values.
    let expected = record_batch!(
        ("id", Int32, [1, 2, 3, 4]),
        ("value", Float64, [1.1, 22.2, 3.3, 44.4]),
        ("extra", Utf8, ["a", "x", "c", "y"])
    )
    .unwrap();
    assert_eq!(result, expected);
}
#[rstest::rstest]
#[tokio::test]
async fn test_when_matched_delete_full_schema(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // WhenMatched::Delete with a full-schema source: matched target rows are
    // deleted, unmatched source rows are ignored (DoNothing), and the planner
    // produces the optimized delete-only plan with an inner join.
    let schema = create_test_schema();
    let test_uri = "memory://test_delete_full.lance";
    let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
    // Source keys 4..=9; the asserts below imply the seed data holds keys
    // 1..=6, so 4, 5 and 6 match and are deleted.
    let new_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![4, 5, 6, 7, 8, 9])),
            Arc::new(UInt32Array::from(vec![2, 2, 2, 2, 2, 2])),
            Arc::new(StringArray::from(vec!["A", "B", "C", "A", "B", "C"])),
        ],
    )
    .unwrap();
    let keys = vec!["key".to_string()];
    // First verify the plan shape: only the "key" column is streamed from the
    // source even though the batch has three columns.
    let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(new_batch.clone())],
        schema.clone(),
    )));
    let plan = plan_job.create_plan(plan_stream).await.unwrap();
    assert_plan_node_equals(
        plan,
        "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
...
HashJoinExec: ...join_type=Inner...
...
...
StreamingTableExec: partition_sizes=1, projection=[key]",
    )
    .await
    .unwrap();
    // Then run the same merge for real.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys)
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
    let new_stream = reader_to_stream(new_reader);
    let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
    assert_eq!(merge_stats.num_deleted_rows, 3);
    assert_eq!(merge_stats.num_inserted_rows, 0);
    assert_eq!(merge_stats.num_updated_rows, 0);
    let batches = merged_dataset
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&schema, &batches).unwrap();
    let mut remaining_keys: Vec<u32> = merged
        .column(0)
        .as_primitive::<UInt32Type>()
        .values()
        .to_vec();
    // sort_unstable: keys are primitive u32, stability is irrelevant
    // (clippy::stable_sort_primitive).
    remaining_keys.sort_unstable();
    assert_eq!(remaining_keys, vec![1, 2, 3]);
}
#[rstest::rstest]
#[tokio::test]
async fn test_when_matched_delete_id_only(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
    #[values(true, false)] enable_stable_row_ids: bool,
) {
    // WhenMatched::Delete with a key-only source batch: deleting does not need
    // the payload columns, so a schema containing just "key" must be accepted
    // and produce the same delete-only plan.
    let test_uri = "memory://test_delete_id_only.lance";
    let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
    let id_only_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::UInt32, true)]));
    // Keys 2, 4, 6 all match seed rows (the asserts below imply the seed data
    // holds keys 1..=6) and are deleted.
    let new_batch = RecordBatch::try_new(
        id_only_schema.clone(),
        vec![Arc::new(UInt32Array::from(vec![2, 4, 6]))],
    )
    .unwrap();
    let keys = vec!["key".to_string()];
    // First verify the optimized delete-only plan shape.
    let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(new_batch.clone())],
        id_only_schema.clone(),
    )));
    let plan = plan_job.create_plan(plan_stream).await.unwrap();
    assert_plan_node_equals(
        plan,
        "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
...
HashJoinExec: ...join_type=Inner...
...
...
StreamingTableExec: partition_sizes=1, projection=[key]",
    )
    .await
    .unwrap();
    // Then run the same merge for real.
    let job = MergeInsertBuilder::try_new(ds.clone(), keys)
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let new_reader = Box::new(RecordBatchIterator::new(
        [Ok(new_batch)],
        id_only_schema.clone(),
    ));
    let new_stream = reader_to_stream(new_reader);
    let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
    assert_eq!(merge_stats.num_deleted_rows, 3);
    assert_eq!(merge_stats.num_inserted_rows, 0);
    assert_eq!(merge_stats.num_updated_rows, 0);
    // Scan with the full schema to confirm the surviving rows.
    let full_schema = create_test_schema();
    let batches = merged_dataset
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&full_schema, &batches).unwrap();
    let mut remaining_keys: Vec<u32> = merged
        .column(0)
        .as_primitive::<UInt32Type>()
        .values()
        .to_vec();
    // sort_unstable: keys are primitive u32, stability is irrelevant
    // (clippy::stable_sort_primitive).
    remaining_keys.sort_unstable();
    assert_eq!(remaining_keys, vec![1, 3, 5]);
}
#[rstest::rstest]
#[tokio::test]
async fn test_when_matched_delete_with_insert(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
) {
    // Delete matched rows while inserting the unmatched ones: matched source
    // keys remove their targets and the remainder get appended.
    let schema = create_test_schema();
    let test_uri = "memory://test_delete_with_insert.lance";
    let ds = create_test_dataset(test_uri, version, false).await;
    let source_batch = create_new_batch(schema.clone());
    let join_keys = vec!["key".to_string()];

    // With inserts in play the generic merge-insert plan must be used,
    // not the delete-only fast path.
    let planning_job = MergeInsertBuilder::try_new(ds.clone(), join_keys.clone())
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap();
    let planning_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(source_batch.clone())],
        schema.clone(),
    )));
    let plan = planning_job.create_plan(planning_stream).await.unwrap();
    assert_plan_node_equals(
        plan,
        "MergeInsert: on=[key], when_matched=Delete, when_not_matched=InsertAll, when_not_matched_by_source=Keep...THEN 2 WHEN...THEN 3 ELSE 0 END as __action]...projection=[key, value, filterme]"
    ).await.unwrap();

    // Execute the merge and check the statistics.
    let exec_job = MergeInsertBuilder::try_new(ds.clone(), join_keys)
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::InsertAll)
        .try_build()
        .unwrap();
    let source_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(source_batch)],
        schema.clone(),
    )));
    let (merged_dataset, merge_stats) = exec_job.execute(source_stream).await.unwrap();
    assert_eq!(merge_stats.num_deleted_rows, 3);
    assert_eq!(merge_stats.num_inserted_rows, 3);
    assert_eq!(merge_stats.num_updated_rows, 0);

    let batches = merged_dataset
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&schema, &batches).unwrap();

    // Matched keys were deleted; unmatched source keys were inserted
    // alongside the untouched originals.
    let mut remaining_keys: Vec<u32> = merged
        .column(0)
        .as_primitive::<UInt32Type>()
        .values()
        .to_vec();
    remaining_keys.sort();
    assert_eq!(remaining_keys, vec![1, 2, 3, 7, 8, 9]);

    // Surviving original rows keep value=1 while inserted rows carry value=2.
    let key_column = merged.column(0).as_primitive::<UInt32Type>().values();
    let value_column = merged.column(1).as_primitive::<UInt32Type>().values();
    for (&key, &value) in key_column.iter().zip(value_column.iter()) {
        if key <= 3 {
            assert_eq!(value, 1, "Original keys should have value=1");
        } else {
            assert_eq!(value, 2, "New keys should have value=2");
        }
    }
}
#[rstest::rstest]
#[tokio::test]
async fn test_when_matched_delete_no_matches(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
) {
    // A delete-only merge whose source keys match nothing must leave the
    // dataset untouched and report zero affected rows.
    let schema = create_test_schema();
    let test_uri = "memory://test_delete_no_matches.lance";
    let ds = create_test_dataset(test_uri, version, false).await;

    // None of these keys exist in the target dataset.
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![100, 200, 300])),
            Arc::new(UInt32Array::from(vec![10, 20, 30])),
            Arc::new(StringArray::from(vec!["X", "Y", "Z"])),
        ],
    )
    .unwrap();
    let join_keys = vec!["key".to_string()];

    // The delete-only plan is still selected regardless of match count.
    let planning_job = MergeInsertBuilder::try_new(ds.clone(), join_keys.clone())
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let planning_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(source_batch.clone())],
        schema.clone(),
    )));
    let plan = planning_job.create_plan(planning_stream).await.unwrap();
    assert_plan_node_equals(
        plan,
        "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
...
HashJoinExec: ...join_type=Inner...
...
...
StreamingTableExec: partition_sizes=1, projection=[key]",
    )
    .await
    .unwrap();

    // Execute: nothing matches, so no rows are deleted, inserted, or updated.
    let exec_job = MergeInsertBuilder::try_new(ds.clone(), join_keys)
        .unwrap()
        .when_matched(WhenMatched::Delete)
        .when_not_matched(WhenNotMatched::DoNothing)
        .try_build()
        .unwrap();
    let source_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(source_batch)],
        schema.clone(),
    )));
    let (merged_dataset, merge_stats) = exec_job.execute(source_stream).await.unwrap();
    assert_eq!(merge_stats.num_deleted_rows, 0);
    assert_eq!(merge_stats.num_inserted_rows, 0);
    assert_eq!(merge_stats.num_updated_rows, 0);

    // All six original rows survive.
    let batches = merged_dataset
        .scan()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let merged = concat_batches(&schema, &batches).unwrap();
    let mut remaining_keys: Vec<u32> = merged
        .column(0)
        .as_primitive::<UInt32Type>()
        .values()
        .to_vec();
    remaining_keys.sort();
    assert_eq!(remaining_keys, vec![1, 2, 3, 4, 5, 6]);
}
#[tokio::test]
async fn test_is_delete_only() {
    // The planner should emit the specialized delete-only node exactly for
    // the (Delete, DoNothing, Keep) action combination, and the generic
    // merge-insert node for every other buildable combination.
    let when_matched_variants = [
        WhenMatched::UpdateAll,
        WhenMatched::DoNothing,
        WhenMatched::Fail,
        WhenMatched::Delete,
    ];
    let when_not_matched_variants = [WhenNotMatched::InsertAll, WhenNotMatched::DoNothing];
    let when_not_matched_by_source_variants =
        [WhenNotMatchedBySource::Keep, WhenNotMatchedBySource::Delete];
    let schema = create_test_schema();

    // Nested loops walk the full cross product in the same order (and with
    // the same per-combination index) as an iproduct!-style enumeration.
    let mut combo_idx = 0usize;
    for when_matched in when_matched_variants.iter().cloned() {
        for when_not_matched in when_not_matched_variants.iter().cloned() {
            for when_not_matched_by_source in when_not_matched_by_source_variants.iter().cloned() {
                let idx = combo_idx;
                combo_idx += 1;

                // A fully inert combination cannot be built into a plan; skip it.
                let is_no_op = matches!(when_matched, WhenMatched::DoNothing | WhenMatched::Fail)
                    && matches!(when_not_matched, WhenNotMatched::DoNothing)
                    && matches!(when_not_matched_by_source, WhenNotMatchedBySource::Keep);
                if is_no_op {
                    continue;
                }

                let test_uri = format!("memory://test_is_delete_only_{}.lance", idx);
                let ds = create_test_dataset(&test_uri, LanceFileVersion::V2_0, false).await;
                let new_batch = RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(UInt32Array::from(vec![4, 5, 6])),
                        Arc::new(UInt32Array::from(vec![2, 2, 2])),
                        Arc::new(StringArray::from(vec!["A", "B", "C"])),
                    ],
                )
                .unwrap();

                let mut builder =
                    MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]).unwrap();
                builder
                    .when_matched(when_matched.clone())
                    .when_not_matched(when_not_matched.clone())
                    .when_not_matched_by_source(when_not_matched_by_source.clone());
                let job = builder.try_build().unwrap();

                let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
                    [Ok(new_batch)],
                    schema.clone(),
                )));
                let plan = job.create_plan(plan_stream).await.unwrap();
                let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
                    .indent(true)
                    .to_string();

                let expected_delete_only = matches!(when_matched, WhenMatched::Delete)
                    && matches!(when_not_matched, WhenNotMatched::DoNothing)
                    && matches!(when_not_matched_by_source, WhenNotMatchedBySource::Keep);
                if expected_delete_only {
                    assert!(
                        plan_str.contains("DeleteOnlyMergeInsert"),
                        "Expected DeleteOnlyMergeInsert for ({:?}, {:?}, {:?}), but got:\n{}",
                        when_matched,
                        when_not_matched,
                        when_not_matched_by_source,
                        plan_str
                    );
                } else {
                    assert!(
                        plan_str.contains("MergeInsert:")
                            && !plan_str.contains("DeleteOnlyMergeInsert"),
                        "Expected MergeInsert (not DeleteOnlyMergeInsert) for ({:?}, {:?}, {:?}), but got:\n{}",
                        when_matched,
                        when_not_matched,
                        when_not_matched_by_source,
                        plan_str
                    );
                }
            }
        }
    }
}
#[tokio::test]
async fn test_apply_deletions_invalid_row_address() {
    use super::exec::apply_deletions;
    use roaring::RoaringTreemap;

    // Build row addresses that point past the end of the only fragment;
    // applying them must surface an error rather than silently dropping them.
    let test_uri = "memory://test_apply_deletions_error.lance";
    let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await;
    let fragment_id = ds.get_fragments()[0].id() as u32;

    // Row addresses encode the fragment id in the upper 32 bits.
    let base = (fragment_id as u64) << 32;
    let mut invalid_row_addrs = RoaringTreemap::new();
    (10..14u64).for_each(|row_offset| {
        invalid_row_addrs.insert(base | row_offset);
    });

    let result = apply_deletions(&ds, &invalid_row_addrs).await;
    assert!(result.is_err(), "Expected error for invalid row addresses");
    let err = result.unwrap_err();
    assert!(
        err.to_string()
            .contains("Deletion vector includes rows that aren't in the fragment"),
        "Expected 'rows that aren't in the fragment' error, got: {}",
        err
    );
}
mod external_error {
    use super::*;
    use arrow_schema::{ArrowError, Field as ArrowField, Schema as ArrowSchema};
    use std::fmt;

    /// A caller-defined error type used to verify that external errors
    /// survive the merge-insert pipeline without being flattened to strings.
    #[derive(Debug)]
    struct MyTestError {
        code: i32,
        details: String,
    }

    impl fmt::Display for MyTestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "MyTestError({}): {}", self.code, self.details)
        }
    }

    impl std::error::Error for MyTestError {}

    #[tokio::test]
    async fn test_merge_insert_execute_reader_preserves_external_error() {
        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("key", DataType::Int32, false),
            ArrowField::new("value", DataType::Int32, false),
        ]));

        // Seed a dataset with one valid batch.
        let seed_batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![10, 20, 30])),
            ],
        )
        .unwrap();
        let seed_reader = RecordBatchIterator::new(vec![Ok(seed_batch)], schema.clone());
        let dataset = Arc::new(
            Dataset::write(seed_reader, "memory://test_merge_external", None)
                .await
                .unwrap(),
        );

        // Feed the merge a source that immediately yields a custom external error.
        let error_code = 789;
        let failing_reader = RecordBatchIterator::new(
            std::iter::once(Err(ArrowError::ExternalError(Box::new(MyTestError {
                code: error_code,
                details: "merge insert failure".to_string(),
            })))),
            schema,
        );
        let result = MergeInsertBuilder::try_new(dataset, vec!["key".to_string()])
            .unwrap()
            .try_build()
            .unwrap()
            .execute_reader(Box::new(failing_reader) as Box<dyn RecordBatchReader + Send>)
            .await;

        // The original error must be recoverable via downcast, not stringified.
        match result {
            Ok(_) => panic!("Expected error"),
            Err(Error::External { source }) => {
                let original = source.downcast_ref::<MyTestError>().unwrap();
                assert_eq!(original.code, error_code);
            }
            Err(other) => panic!("Expected External, got: {:?}", other),
        }
    }
}
}