/// Name of the column that `create_plan` adds to the joined target/source plan.
/// It holds the per-row merge action computed by `merge_insert_action`.
const MERGE_ACTION_COLUMN: &str = "__action";
/// Name of a constant `true` column that `create_plan` appends to the source side
/// of the join.
// NOTE(review): after an outer join this column is presumably null on rows with no
// source match, letting downstream code detect source presence — confirm in the
// consumers (it is `pub(super)`, so it is read elsewhere).
pub(super) const MERGE_SOURCE_SENTINEL: &str = "__merge_source_sentinel";
38
39pub mod inserted_rows;
40
41use assign_action::merge_insert_action;
42use inserted_rows::KeyExistenceFilter;
43
44use super::retry::{RetryConfig, RetryExecutor, execute_with_retry};
45use super::{CommitBuilder, WriteParams, write_fragments_internal};
46use crate::dataset::rowids::get_row_id_index;
47use crate::dataset::transaction::UpdateMode::{RewriteColumns, RewriteRows};
48use crate::dataset::utils::CapturedRowIds;
49use crate::index::DatasetIndexExt;
50use crate::{
51 Dataset,
52 datafusion::dataframe::SessionContextExt,
53 dataset::{
54 fragment::{FileFragment, FragReadConfig},
55 transaction::{Operation, Transaction},
56 write::{merge_insert::logical_plan::MergeInsertPlanner, open_writer},
57 },
58 index::DatasetIndexInternalExt,
59 io::exec::{
60 AddRowAddrExec, Planner, TakeExec, project, scalar_index::MapIndexExec, utils::ReplayExec,
61 },
62};
63use arrow_array::{
64 BooleanArray, RecordBatch, RecordBatchIterator, StructArray, UInt32Array, UInt64Array,
65 cast::AsArray, types::UInt64Type,
66};
67use arrow_schema::{DataType, Field, Schema};
68use arrow_select::take::take_record_batch;
69use datafusion::common::NullEquality;
70use datafusion::common::tree_node::{Transformed, TreeNode};
71use datafusion::error::DataFusionError;
72use datafusion::{
73 execution::{
74 context::{SessionConfig, SessionContext},
75 memory_pool::MemoryConsumer,
76 },
77 logical_expr::{self, Expr, Extension, JoinType, LogicalPlan},
78 physical_plan::{
79 ColumnarValue, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream,
80 display::DisplayableExecutionPlan,
81 joins::{HashJoinExec, PartitionMode},
82 projection::ProjectionExec,
83 repartition::RepartitionExec,
84 sorts::sort::SortExec,
85 stream::RecordBatchStreamAdapter,
86 union::UnionExec,
87 },
88 physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
89 prelude::DataFrame,
90 scalar::ScalarValue,
91};
92use datafusion_physical_expr::expressions::Column;
93use futures::{
94 Stream, StreamExt, TryStreamExt,
95 stream::{self},
96};
97use lance_arrow::{RecordBatchExt, SchemaExt, interleave_batches};
98use lance_core::datatypes::NullabilityComparison;
99use lance_core::utils::address::RowAddress;
100use lance_core::{
101 Error, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, Result,
102 datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions},
103 error::{InvalidInputSnafu, box_error},
104 utils::{futures::Capacity, mask::RowAddrTreeMap, tokio::get_num_compute_intensive_cpus},
105};
106use lance_datafusion::{
107 chunker::chunk_stream,
108 dataframe::BatchStreamGrouper,
109 exec::{
110 HardCapBatchSizeExec, LanceExecutionOptions, OneShotExec, analyze_plan, execute_plan,
111 get_session_context,
112 },
113 utils::{StreamingWriteSource, reader_to_stream},
114};
115use lance_file::version::LanceFileVersion;
116use lance_index::IndexCriteria;
117use lance_index::mem_wal::MergedGeneration;
118use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
119use log::info;
120use roaring::RoaringTreemap;
121use snafu::ResultExt;
122use std::{
123 collections::{BTreeMap, HashSet},
124 sync::{
125 Arc, Mutex,
126 atomic::{AtomicU32, Ordering},
127 },
128 time::Duration,
129};
130use tokio::task::JoinSet;
131use tracing::error;
132
133mod assign_action;
134mod exec;
135mod logical_plan;
136
137fn combined_schema(schema: &Schema) -> Schema {
141 let target = Field::new("target", DataType::Struct(schema.fields.clone()), false);
142 let source = Field::new("source", DataType::Struct(schema.fields.clone()), false);
143 Schema::new(vec![source, target])
144}
145
146fn unzip_batch(batch: &RecordBatch, schema: &Schema) -> RecordBatch {
150 let num_fields = batch.num_columns();
154 debug_assert_eq!(num_fields % 2, 1);
155 let half_num_fields = num_fields / 2;
156 let row_id_col = num_fields - 1;
157
158 let source_arrays = batch.columns()[0..half_num_fields].to_vec();
159 let source = StructArray::new(schema.fields.clone(), source_arrays, None);
160
161 let target_arrays = batch.columns()[half_num_fields..row_id_col].to_vec();
162 let target = StructArray::new(schema.fields.clone(), target_arrays, None);
163
164 let combined_schema = combined_schema(schema);
165 RecordBatch::try_new(
166 Arc::new(combined_schema),
167 vec![Arc::new(source), Arc::new(target)],
168 )
169 .unwrap()
170}
171
172pub fn format_key_values_on_columns(
174 batch: &RecordBatch,
175 row_idx: usize,
176 on_columns: &[String],
177) -> String {
178 let mut on_values = Vec::new();
179
180 for col_name in on_columns {
181 if let Some(col_idx) = batch.schema().column_with_name(col_name) {
182 let column = batch.column(col_idx.0);
183 let value_str = if column.is_null(row_idx) {
184 "NULL".to_string()
185 } else {
186 match ScalarValue::try_from_array(column, row_idx) {
188 Ok(scalar_value) => match &scalar_value {
189 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
190 format!("\"{}\"", s)
191 }
192 _ => scalar_value.to_string(),
193 },
194 Err(_) => format!("<{:?}>", column.data_type()),
195 }
196 };
197 on_values.push(format!("{} = {}", col_name, value_str));
198 }
199 }
200
201 if on_values.is_empty() {
202 "<unable to extract on column values>".to_string()
203 } else {
204 on_values.join(", ")
205 }
206}
207
208pub fn create_duplicate_row_error(
210 batch: &RecordBatch,
211 row_idx: usize,
212 on_columns: &[String],
213) -> DataFusionError {
214 DataFusionError::External(Box::new(Error::invalid_input(format!(
215 "Ambiguous merge inserts are prohibited: multiple source rows match the same target row on ({}). \
216 Please ensure each target row is matched by at most one source row.",
217 format_key_values_on_columns(batch, row_idx, on_columns)
218 ))))
219}
220
/// What a merge insert does with target rows that no source row matches.
// Variant order matters: it defines the derived `PartialOrd`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WhenNotMatchedBySource {
    /// Leave unmatched target rows in place. This is the only setting that permits the
    /// scalar-index join path.
    Keep,
    /// Delete unmatched target rows.
    Delete,
    /// Delete unmatched target rows selected by this filter expression, built by
    /// [`WhenNotMatchedBySource::delete_if`].
    // NOTE(review): presumably only rows for which the filter evaluates to true are
    // deleted — confirm in the planner.
    DeleteIf(Expr),
}
237
238impl WhenNotMatchedBySource {
239 pub fn delete_if(dataset: &Dataset, expr: &str) -> Result<Self> {
245 let planner = Planner::new(Arc::new(dataset.schema().into()));
246 let expr = planner
247 .parse_filter(expr)
248 .map_err(box_error)
249 .context(InvalidInputSnafu {})?;
250 let expr = planner
251 .optimize_expr(expr)
252 .map_err(box_error)
253 .context(InvalidInputSnafu {})?;
254 Ok(Self::DeleteIf(expr))
255 }
256}
257
/// What a merge insert does with target rows that a source row matches.
// Variant order matters: it defines the derived `PartialOrd`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WhenMatched {
    /// Update matched target rows with the values of the matching source row.
    UpdateAll,
    /// Leave matched target rows unchanged (the builder's default).
    DoNothing,
    /// Update matched rows only when this condition holds.
    /// The condition is stored as an unparsed string; see [`WhenMatched::update_if`].
    // NOTE(review): the expression is presumably evaluated against source and target
    // columns at plan time — confirm.
    UpdateIf(String),
    /// Fail the operation when a target row is matched.
    // NOTE(review): inferred from the name — confirm in `merge_insert_action`.
    Fail,
    /// Delete matched target rows.
    Delete,
}
282
283impl WhenMatched {
284 pub fn update_if(_dataset: &Dataset, expr: &str) -> Result<Self> {
285 Ok(Self::UpdateIf(expr.to_string()))
287 }
288}
289
/// What a merge insert does with source rows that match no target row.
// Variant order matters: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum WhenNotMatched {
    /// Insert unmatched source rows. Sets `insert_not_matched = true` (the default).
    InsertAll,
    /// Discard unmatched source rows. Sets `insert_not_matched = false`.
    DoNothing,
}
302
/// How a merge insert treats several source rows that share the same join key.
// Variant order matters: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum SourceDedupeBehavior {
    /// Reject duplicates with an error (the default).
    // NOTE(review): presumably reported via `create_duplicate_row_error` — confirm.
    #[default]
    Fail,
    /// Keep the first source row seen for a key.
    // NOTE(review): inferred from the name; confirm how later duplicates are dropped.
    FirstSeen,
}
315
/// Configuration for a merge insert, filled in by [`MergeInsertBuilder`] and carried by
/// [`MergeInsertJob`].
// Field order matters: it defines the derived `PartialOrd`/`Hash`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct MergeInsertParams {
    /// Join key column names, as resolved by `MergeInsertBuilder::try_new`.
    on: Vec<String>,
    /// Action for target rows that a source row matches.
    when_matched: WhenMatched,
    /// Whether source rows without a target match are inserted. Set through
    /// `MergeInsertBuilder::when_not_matched`.
    insert_not_matched: bool,
    /// Action for target rows that no source row matches.
    delete_not_matched_by_source: WhenNotMatchedBySource,
    /// Maximum retries on commit conflict. Used as `RetryConfig::max_retries` in `execute`;
    /// `execute` also passes `conflict_retries > 0` to `new_source_iter`.
    conflict_retries: u32,
    /// Timeout for the conflict-retry loop, used as `RetryConfig::retry_timeout`.
    retry_timeout: Duration,
    /// MemWAL generations to record as merged, appended by `mark_generations_as_merged`.
    merged_generations: Vec<MergedGeneration>,
    /// Whether automatic cleanup should be skipped. It is not read in this chunk.
    skip_auto_cleanup: bool,
    /// Whether a scalar index on the join key may be used to avoid a full table scan.
    use_index: bool,
    /// Handling of duplicate source rows for the same key.
    source_dedupe_behavior: SourceDedupeBehavior,
    /// Explicit commit retry count. `None` unless set by `commit_retries`; it is not read in
    /// this chunk.
    commit_retries: Option<u32>,
}
343
/// A validated merge insert operation against `dataset`, produced by
/// [`MergeInsertBuilder::try_build`].
#[derive(Clone)]
pub struct MergeInsertJob {
    /// Target dataset the merge is applied to.
    dataset: Arc<Dataset>,
    /// Configuration copied from the builder.
    params: MergeInsertParams,
}
353
/// Builder for a [`MergeInsertJob`].
///
/// Create one with `try_new`, adjust it with the setter methods, then call `try_build`.
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder {
    /// Target dataset the merge will be applied to.
    dataset: Arc<Dataset>,
    /// Configuration accumulated by the setters.
    params: MergeInsertParams,
}
402
403impl MergeInsertBuilder {
404 pub fn try_new(dataset: Arc<Dataset>, on: Vec<String>) -> Result<Self> {
413 let resolved_on = if on.is_empty() {
416 let schema = dataset.schema();
417 let pk_fields = schema.unenforced_primary_key();
418
419 if pk_fields.is_empty() {
420 return Err(Error::invalid_input(
421 "A merge insert operation requires join keys: specify `on` columns explicitly or configure a primary key in the dataset schema",
422 ));
423 }
424
425 pk_fields
426 .iter()
427 .map(|field| schema.field_path(field.id))
428 .collect::<Result<Vec<_>>>()?
429 } else {
430 on.iter()
433 .map(|col| {
434 dataset
435 .schema()
436 .field_case_insensitive(col)
437 .map(|f| f.name.clone())
438 .ok_or_else(|| {
439 Error::invalid_input(format!(
440 "Merge insert key column '{}' does not exist in schema",
441 col
442 ))
443 })
444 })
445 .collect::<Result<Vec<_>>>()?
446 };
447
448 Ok(Self {
449 dataset,
450 params: MergeInsertParams {
451 on: resolved_on,
452 when_matched: WhenMatched::DoNothing,
453 insert_not_matched: true,
454 delete_not_matched_by_source: WhenNotMatchedBySource::Keep,
455 conflict_retries: 10,
456 retry_timeout: Duration::from_secs(30),
457 merged_generations: Vec::new(),
458 skip_auto_cleanup: false,
459 use_index: true,
460 source_dedupe_behavior: SourceDedupeBehavior::Fail,
461 commit_retries: None,
462 },
463 })
464 }
465
466 pub fn when_matched(&mut self, behavior: WhenMatched) -> &mut Self {
468 self.params.when_matched = behavior;
469 self
470 }
471
472 pub fn when_not_matched(&mut self, behavior: WhenNotMatched) -> &mut Self {
476 self.params.insert_not_matched = match behavior {
477 WhenNotMatched::DoNothing => false,
478 WhenNotMatched::InsertAll => true,
479 };
480 self
481 }
482
483 pub fn when_not_matched_by_source(&mut self, behavior: WhenNotMatchedBySource) -> &mut Self {
487 self.params.delete_not_matched_by_source = behavior;
488 self
489 }
490
491 pub fn conflict_retries(&mut self, retries: u32) -> &mut Self {
499 self.params.conflict_retries = retries;
500 self
501 }
502
503 pub fn retry_timeout(&mut self, timeout: Duration) -> &mut Self {
513 self.params.retry_timeout = timeout;
514 self
515 }
516
517 pub fn skip_auto_cleanup(&mut self, skip: bool) -> &mut Self {
518 self.params.skip_auto_cleanup = skip;
519 self
520 }
521
522 pub fn use_index(&mut self, use_index: bool) -> &mut Self {
529 self.params.use_index = use_index;
530 self
531 }
532
533 pub fn source_dedupe_behavior(&mut self, behavior: SourceDedupeBehavior) -> &mut Self {
541 self.params.source_dedupe_behavior = behavior;
542 self
543 }
544
545 pub fn mark_generations_as_merged(&mut self, generations: Vec<MergedGeneration>) -> &mut Self {
548 self.params.merged_generations.extend(generations);
549 self
550 }
551
552 pub fn commit_retries(&mut self, retries: u32) -> &mut Self {
556 self.params.commit_retries = Some(retries);
557 self
558 }
559
560 pub fn try_build(&mut self) -> Result<MergeInsertJob> {
562 if !self.params.insert_not_matched
563 && self.params.when_matched == WhenMatched::DoNothing
564 && self.params.delete_not_matched_by_source == WhenNotMatchedBySource::Keep
565 {
566 return Err(Error::invalid_input(
567 "The merge insert job is not configured to change the data in any way",
568 ));
569 }
570 Ok(MergeInsertJob {
571 dataset: self.dataset.clone(),
572 params: self.params.clone(),
573 })
574 }
575}
576
/// How a source schema relates to the target dataset schema.
/// Produced by `MergeInsertJob::check_compatible_schema`.
enum SchemaComparison {
    /// Compatible with the full target schema (nullability ignored).
    FullCompatible,
    /// Compatible only as a subset of the target's fields, with field order ignored.
    Subschema,
}
581
582impl MergeInsertJob {
583 pub async fn execute_reader(
584 self,
585 source: impl StreamingWriteSource,
586 ) -> Result<(Arc<Dataset>, MergeStats)> {
587 let stream = source.into_stream();
588 self.execute(stream).await
589 }
590
591 fn check_compatible_schema(&self, schema: &Schema) -> Result<SchemaComparison> {
592 let lance_schema: lance_core::datatypes::Schema = schema.try_into()?;
593 let target_schema = self.dataset.schema();
594
595 let mut options = SchemaCompareOptions {
596 compare_dictionary: self.dataset.is_legacy_storage(),
597 compare_nullability: NullabilityComparison::Ignore,
598 ..Default::default()
599 };
600
601 if lance_schema
603 .check_compatible(target_schema, &options)
604 .is_ok()
605 {
606 return Ok(SchemaComparison::FullCompatible);
607 }
608
609 options.allow_subschema = true;
611 options.ignore_field_order = true; lance_schema
614 .check_compatible(target_schema, &options)
615 .map(|_| SchemaComparison::Subschema)
616 }
617
618 async fn join_key_as_scalar_index(&self) -> Result<Option<IndexMetadata>> {
619 if self.params.on.len() != 1 {
620 Ok(None)
622 } else {
623 let col = &self.params.on[0];
624 self.dataset
625 .load_scalar_index(
626 IndexCriteria::default()
627 .for_column(col)
628 .supports_exact_equality(),
630 )
631 .await
632 }
633 }
634
635 async fn create_indexed_scan_joined_stream(
636 &self,
637 source: SendableRecordBatchStream,
638 index: IndexMetadata,
639 ) -> Result<SendableRecordBatchStream> {
640 let schema = source.schema();
643 let add_row_addr = match self.check_compatible_schema(&schema)? {
644 SchemaComparison::FullCompatible => false,
645 SchemaComparison::Subschema => true,
646 };
647
648 let input = Arc::new(OneShotExec::new(source));
650
651 let shared_input = Arc::new(ReplayExec::new(Capacity::Unbounded, input));
655
656 let field = schema.field_with_name(&self.params.on[0])?;
659 let index_mapper_input = Arc::new(project(
660 shared_input.clone(),
661 &Schema::new(vec![field.clone()]),
663 )?);
664
665 let index_column = self.params.on[0].clone();
667 let mut index_mapper: Arc<dyn ExecutionPlan> = Arc::new(MapIndexExec::new(
668 self.dataset.clone(),
670 index_column.clone(),
671 index.name.clone(),
672 index_mapper_input,
673 ));
674
675 if add_row_addr {
677 let pos = index_mapper.schema().fields().len(); index_mapper = Arc::new(AddRowAddrExec::try_new(
679 index_mapper,
680 self.dataset.clone(),
681 pos,
682 )?);
683 }
684
685 let projection = self
687 .dataset
688 .empty_projection()
689 .union_arrow_schema(schema.as_ref(), OnMissing::Error)?;
690 let mut target =
691 Arc::new(TakeExec::try_new(self.dataset.clone(), index_mapper, projection)?.unwrap())
692 as Arc<dyn ExecutionPlan>;
693
694 let schema = target.schema();
698 let mut columns = schema
699 .fields()
700 .iter()
701 .filter(|f| f.name() != ROW_ID && f.name() != ROW_ADDR)
702 .cloned()
703 .collect::<Vec<_>>();
704 columns.push(Arc::new(ROW_ID_FIELD.clone()));
705 if add_row_addr {
706 columns.push(Arc::new(ROW_ADDR_FIELD.clone()));
707 }
708 target = Arc::new(project(target, &Schema::new(columns))?);
709
710 let column_names = schema
711 .field_names()
712 .into_iter()
713 .filter(|name| name.as_str() != ROW_ID && name.as_str() != ROW_ADDR)
714 .collect::<Vec<_>>();
715
716 let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?;
718 if !unindexed_fragments.is_empty() {
719 let mut builder = self.dataset.scan();
720 if add_row_addr {
721 builder.with_row_address();
722 }
723 let unindexed_data = builder
724 .with_row_id()
725 .with_fragments(unindexed_fragments)
726 .project(&column_names)
727 .unwrap()
728 .create_plan()
729 .await?;
730 let unioned = UnionExec::try_new(vec![target, unindexed_data])?;
731 target = Arc::new(RepartitionExec::try_new(
733 unioned,
734 datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
735 )?);
736 }
737
738 target = Self::prefix_columns_phys(target, "target_");
741
742 let source_key = Column::new_with_schema(&index_column, shared_input.schema().as_ref())?;
744 let target_key = Column::new_with_schema(
745 &format!("target_{}", index_column),
746 target.schema().as_ref(),
747 )?;
748 let joined = Arc::new(
749 HashJoinExec::try_new(
750 shared_input,
751 target,
752 vec![(Arc::new(source_key), Arc::new(target_key))],
753 None,
754 &JoinType::Full,
755 None,
756 PartitionMode::CollectLeft,
757 NullEquality::NullEqualsNull,
758 false,
759 )
760 .unwrap(),
761 );
762 execute_plan(
763 joined,
764 LanceExecutionOptions {
765 use_spilling: true,
766 ..Default::default()
767 },
768 )
769 }
770
771 fn prefix_columns(df: DataFrame, prefix: &str) -> DataFrame {
772 let schema = df.schema();
773 let columns = schema
774 .fields()
775 .iter()
776 .map(|f| {
777 logical_expr::col(format!("\"{}\"", f.name())).alias(format!(
779 "{}{}",
780 prefix,
781 f.name()
782 ))
783 })
784 .collect::<Vec<_>>();
785 df.select(columns).unwrap()
786 }
787
788 fn prefix_columns_phys(inp: Arc<dyn ExecutionPlan>, prefix: &str) -> Arc<dyn ExecutionPlan> {
789 let schema = inp.schema();
790 let exprs = schema
791 .fields()
792 .iter()
793 .enumerate()
794 .map(|(idx, f)| {
795 let col = Arc::new(Column::new(f.name(), idx)) as Arc<dyn PhysicalExpr>;
796 let new_name = format!("{}{}", prefix, f.name());
797 (col, new_name)
798 })
799 .collect::<Vec<_>>();
800 Arc::new(ProjectionExec::try_new(exprs, inp).unwrap())
801 }
802
803 async fn create_full_table_joined_stream(
805 &self,
806 source: SendableRecordBatchStream,
807 ) -> Result<SendableRecordBatchStream> {
808 let session_config = SessionConfig::default().with_target_partitions(1);
809 let session_ctx = SessionContext::new_with_config(session_config);
810 let schema = source.schema();
811 let new_data = session_ctx.read_one_shot(source)?;
812 let join_cols = self
813 .params
814 .on .iter()
816 .map(|c| c.as_str())
817 .collect::<Vec<_>>(); let target_cols = self
819 .params
820 .on
821 .iter()
822 .map(|c| format!("target_{}", c))
823 .collect::<Vec<_>>();
824 let target_cols = target_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();
825
826 match self.check_compatible_schema(&schema)? {
827 SchemaComparison::FullCompatible => {
828 let existing = session_ctx.read_lance(self.dataset.clone(), true, false)?;
829 let existing = Self::prefix_columns(existing, "target_");
831 let joined =
832 new_data.join(existing, JoinType::Full, &join_cols, &target_cols, None)?; Ok(joined.execute_stream().await?)
834 }
835 SchemaComparison::Subschema => {
836 let existing = session_ctx.read_lance(self.dataset.clone(), true, true)?;
837 let columns = schema
838 .field_names()
839 .iter()
840 .map(|s| s.as_str())
841 .chain([ROW_ID, ROW_ADDR])
842 .collect::<Vec<_>>();
843 let projected = existing.select_columns(&columns)?;
844 let projected = Self::prefix_columns(projected, "target_");
846 let join_type = if self.params.insert_not_matched {
848 JoinType::Left
849 } else {
850 JoinType::Inner
851 };
852 let joined = new_data.join(projected, join_type, &join_cols, &target_cols, None)?;
853 Ok(joined.execute_stream().await?)
854 }
855 }
856 }
857
858 async fn create_joined_stream(
866 &self,
867 source: SendableRecordBatchStream,
868 ) -> Result<SendableRecordBatchStream> {
869 if self.params.use_index
870 && matches!(
871 self.params.delete_not_matched_by_source,
872 WhenNotMatchedBySource::Keep
873 )
874 {
875 if let Some(index) = self.join_key_as_scalar_index().await? {
877 return self.create_indexed_scan_joined_stream(source, index).await;
878 }
879 }
880
881 if !matches!(
882 self.params.delete_not_matched_by_source,
883 WhenNotMatchedBySource::Keep
884 ) {
885 info!(
886 "The merge insert operation is configured to delete rows from the target table, this requires a potentially costly full table scan"
887 );
888 }
889
890 self.create_full_table_joined_stream(source).await
891 }
892
    /// Writes the rows of a joined merge-insert stream back to the dataset as column rewrites.
    ///
    /// Every row of `source` must carry a `ROW_ADDR` column. It is used to derive the fragment
    /// id and to order rows.
    ///
    /// Processing steps:
    /// 1. Rows are sorted by row address (ascending, nulls first).
    /// 2. Rows are grouped by fragment id (`ROW_ADDR >> 32`).
    /// 3. Each group is handled by a spawned task:
    ///    - a group with a fragment id rewrites the updated columns of that existing fragment
    ///      (`handle_fragment`);
    ///    - a group whose fragment id is null is written as new fragments
    ///      (`handle_new_fragments`).
    ///
    /// Concurrency is bounded by the object store's IO parallelism and by a memory reservation
    /// sized to each group's batches.
    ///
    /// Returns `(updated_fragments, new_fragments, updated_field_ids)`.
    /// - `updated_field_ids` is the set of field ids written in the newest data file of each
    ///   updated fragment.
    /// - In every older data file of an updated fragment, those field ids are replaced with `-2`.
    ///
    /// `current_version` is forwarded to the stable-row-id metadata refresh helpers.
    async fn update_fragments(
        dataset: Arc<Dataset>,
        source: SendableRecordBatchStream,
        current_version: u64,
    ) -> Result<(Vec<Fragment>, Vec<Fragment>, Vec<u32>)> {
        use datafusion::logical_expr::{col, lit};
        let session_ctx = get_session_context(&LanceExecutionOptions {
            use_spilling: true,
            target_partition: Some(get_num_compute_intensive_cpus().min(8)),
            ..Default::default()
        });
        // Upper bound on the size of each batch fed into the sort (see the transform below).
        const MAX_BATCH_BYTES: usize = 25 * 1024 * 1024;
        // The fragment id is the upper 32 bits of the row address. Sorting by row address keeps
        // each fragment's rows contiguous and ordered by offset.
        let sorted = session_ctx
            .read_one_shot(source)?
            .with_column("_fragment_id", col(ROW_ADDR) >> lit(32))?
            .sort(vec![col(ROW_ADDR).sort(true, true)])?;
        let sorted_plan = sorted.create_physical_plan().await?;
        // Insert a HardCapBatchSizeExec beneath every SortExec so the sort never receives a
        // batch larger than MAX_BATCH_BYTES.
        let capped_plan = sorted_plan
            .transform_down(|node| {
                if node.as_any().downcast_ref::<SortExec>().is_some() {
                    let children = node.children();
                    let new_children: Vec<Arc<dyn ExecutionPlan>> = children
                        .into_iter()
                        .map(|c| {
                            Arc::new(HardCapBatchSizeExec::new(c.clone(), MAX_BATCH_BYTES))
                                as Arc<dyn ExecutionPlan>
                        })
                        .collect();
                    let new_node = node.with_new_children(new_children)?;
                    Ok(Transformed::yes(new_node))
                } else {
                    Ok(Transformed::no(node))
                }
            })?
            .data;
        // NOTE(review): only partition 0 is executed, which assumes the sorted plan has a
        // single output partition — confirm.
        let capped_stream = capped_plan.execute(0, session_ctx.task_ctx())?;
        let mut group_stream = BatchStreamGrouper::new(capped_stream, "_fragment_id".into());

        // Results are collected from the spawned tasks through these shared vectors.
        let updated_fragments = Arc::new(Mutex::new(Vec::new()));
        let new_fragments = Arc::new(Mutex::new(Vec::new()));
        let mut tasks = JoinSet::new();
        let task_limit = dataset.object_store.as_ref().io_parallelism();
        let reservation =
            MemoryConsumer::new("MergeInsert").register(session_ctx.task_ctx().memory_pool());

        while let Some((frag_id, batches)) = group_stream.next().await.transpose()? {
            /// Rewrites the updated columns of one existing fragment.
            ///
            /// `batches` hold this fragment's updated rows, each with a `ROW_ADDR` column.
            /// They are assumed to be in ascending row-address order, because the stream was
            /// sorted by `ROW_ADDR`; the interleave logic below relies on this.
            ///
            /// Two write paths:
            /// - If the number of updated rows equals the fragment's `physical_rows`, the updated
            ///   columns are written to a new data file appended to the fragment's files.
            /// - Otherwise the fragment updater reads the existing values of those columns. Each
            ///   output batch interleaves original rows with updated rows, matched by row address.
            ///
            /// The resulting fragment metadata is pushed onto `updated_fragments`. Returns
            /// `reservation_size` so the caller can release that much of its reservation.
            async fn handle_fragment(
                dataset: Arc<Dataset>,
                fragment: FileFragment,
                mut metadata: Fragment,
                mut batches: Vec<RecordBatch>,
                updated_fragments: Arc<Mutex<Vec<Fragment>>>,
                reservation_size: usize,
                current_version: u64,
            ) -> Result<usize> {
                // Columns to write: everything in the update except the row address and row id,
                // resolved against the dataset schema.
                let write_schema = batches[0]
                    .schema()
                    .as_ref()
                    .without_column(ROW_ADDR)
                    .without_column(ROW_ID);
                let write_schema = dataset.schema().project_by_schema(
                    &write_schema,
                    OnMissing::Error,
                    OnTypeMismatch::Error,
                )?;

                let updated_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
                if Some(updated_rows) == metadata.physical_rows {
                    // Every row of the fragment was updated: write the new column values as a
                    // fresh data file.
                    let data_storage_version = dataset
                        .manifest()
                        .data_storage_format
                        .lance_file_version()?;
                    let mut writer = open_writer(
                        &dataset.object_store,
                        &write_schema,
                        &dataset.base,
                        data_storage_version,
                    )
                    .await?;

                    // The row address is not stored in the data file.
                    batches
                        .iter_mut()
                        .try_for_each(|batch| match batch.drop_column(ROW_ADDR) {
                            Ok(b) => {
                                *batch = b;
                                Ok(())
                            }
                            Err(e) => Err(e),
                        })?;

                    if data_storage_version == LanceFileVersion::Legacy {
                        // Legacy files: re-chunk the data to the batch size of the fragment's
                        // first existing batch.
                        let reader = fragment
                            .open(
                                dataset.schema(),
                                FragReadConfig::default().with_row_address(true),
                            )
                            .await?;
                        let batch_size = reader.legacy_num_rows_in_batch(0).unwrap();
                        let stream = stream::iter(batches.into_iter().map(Ok));
                        let stream = Box::pin(RecordBatchStreamAdapter::new(
                            Arc::new((&write_schema).into()),
                            stream,
                        ));
                        let mut stream = chunk_stream(stream, batch_size as usize);
                        while let Some(chunk) = stream.next().await {
                            writer.write(&chunk?).await?;
                        }
                    } else {
                        writer.write(batches.as_slice()).await?;
                    }

                    let (_num_rows, data_file) = writer.finish().await?;

                    metadata.files.push(data_file);

                    if dataset.manifest.uses_stable_row_ids() {
                        lance_table::rowids::version::refresh_row_latest_update_meta_for_full_frag_rewrite_cols(
                            &mut metadata,
                            current_version,
                        )?;
                    }

                    updated_fragments.lock().unwrap().push(metadata);
                } else {
                    // Only some rows were updated: stream the existing values through the
                    // updater and swap in the updated rows.
                    let update_schema = batches[0].schema();
                    let read_columns = update_schema.field_names();
                    let mut updater = fragment
                        .updater(
                            Some(&read_columns),
                            Some((write_schema, dataset.schema().clone())),
                            None,
                        )
                        .await?;

                    // Slot 0 holds the current original batch from the updater and is replaced
                    // each iteration. Slots 1.. hold the updated batches without `ROW_ADDR`.
                    let mut source_batches = Vec::with_capacity(batches.len() + 1);
                    source_batches.push(batches[0].clone());
                    for batch in &batches {
                        source_batches.push(batch.drop_column(ROW_ADDR)?);
                    }

                    /// Yields `(row_addr, (slot, offset))` for every updated row.
                    /// `slot` indexes `source_batches`, which is why it is shifted by one.
                    fn get_row_addr_iter(
                        batches: &[RecordBatch],
                    ) -> impl Iterator<Item = (u64, (usize, usize))> + '_ + Send
                    {
                        batches.iter().enumerate().flat_map(|(batch_idx, batch)| {
                            // Slot 0 is reserved for the original batch.
                            let batch_idx = batch_idx + 1;
                            let row_addrs = batch
                                .column_by_name(ROW_ADDR)
                                .unwrap()
                                .as_any()
                                .downcast_ref::<UInt64Array>()
                                .unwrap();
                            row_addrs
                                .values()
                                .iter()
                                .enumerate()
                                .map(move |(offset, row_addr)| (*row_addr, (batch_idx, offset)))
                        })
                    }
                    let mut updated_row_addr_iter = get_row_addr_iter(&batches).peekable();

                    while let Some(batch) = updater.next().await? {
                        source_batches[0] =
                            batch.project_by_schema(source_batches[1].schema().as_ref())?;

                        let original_row_addrs = batch
                            .column_by_name(ROW_ADDR)
                            .unwrap()
                            .as_any()
                            .downcast_ref::<UInt64Array>()
                            .unwrap();
                        // For each original row, pick the updated row with the same address if
                        // one exists; otherwise keep the original row from slot 0.
                        let indices = original_row_addrs
                            .values()
                            .into_iter()
                            .enumerate()
                            .map(|(original_offset, row_addr)| {
                                match updated_row_addr_iter.peek() {
                                    Some((updated_row_addr, _))
                                        if *updated_row_addr == *row_addr =>
                                    {
                                        updated_row_addr_iter.next().unwrap().1
                                    }
                                    Some((updated_row_addr, _)) => {
                                        // Updated addresses are sorted, so the next one must lie
                                        // ahead of the current original row.
                                        debug_assert!(
                                            *updated_row_addr > *row_addr,
                                            "Got updated row address that is not in the original batch"
                                        );
                                        (0, original_offset)
                                    }
                                    _ => (0, original_offset),
                                }
                            })
                            .collect::<Vec<_>>();

                        let updated_batch = interleave_batches(&source_batches, &indices)?;

                        updater.update(updated_batch).await?;
                    }

                    let mut updated_fragment = updater.finish().await?;

                    if dataset.manifest.uses_stable_row_ids() {
                        // Collect the distinct row offsets within this fragment that were
                        // updated.
                        let mut updated_offsets: Vec<usize> = Vec::new();
                        for b in batches.iter() {
                            let row_addrs = b
                                .column_by_name(ROW_ADDR)
                                .unwrap()
                                .as_any()
                                .downcast_ref::<UInt64Array>()
                                .unwrap();
                            updated_offsets.extend(
                                row_addrs
                                    .values()
                                    .iter()
                                    .map(|addr| RowAddress::from(*addr).row_offset() as usize),
                            );
                        }
                        updated_offsets.sort_unstable();
                        updated_offsets.dedup();

                        lance_table::rowids::version::refresh_row_latest_update_meta_for_partial_frag_rewrite_cols(
                            &mut updated_fragment,
                            &updated_offsets,
                            current_version,
                            dataset.manifest.version,
                        )?;
                    }

                    updated_fragments.lock().unwrap().push(updated_fragment);
                }
                Ok(reservation_size)
            }

            /// Writes rows that belong to no existing fragment as new fragments.
            ///
            /// The `ROW_ADDR` column is dropped, the remaining columns are resolved against the
            /// dataset schema, and the rows are written with `write_fragments_internal`. The
            /// created fragments are appended to `new_fragments`. Returns `reservation_size` so
            /// the caller can release that much of its reservation.
            async fn handle_new_fragments(
                dataset: Arc<Dataset>,
                batches: Vec<RecordBatch>,
                new_fragments: Arc<Mutex<Vec<Fragment>>>,
                reservation_size: usize,
            ) -> Result<usize> {
                // Project away the row address column.
                let num_fields = batches[0].schema().fields().len();
                let mut projection = Vec::with_capacity(num_fields - 1);
                for (i, field) in batches[0].schema().fields().iter().enumerate() {
                    if field.name() != ROW_ADDR {
                        projection.push(i);
                    }
                }
                let write_schema = Arc::new(batches[0].schema().project(&projection).unwrap());

                let batches = batches
                    .into_iter()
                    .map(move |batch| batch.project(&projection));
                let reader = RecordBatchIterator::new(batches, write_schema.clone());
                let stream = reader_to_stream(Box::new(reader));

                let write_schema = dataset.schema().project_by_schema(
                    write_schema.as_ref(),
                    OnMissing::Error,
                    OnTypeMismatch::Error,
                )?;

                // NOTE(review): the trailing `Default::default()` / `None` arguments are
                // presumably default write params and no progress/target override — confirm
                // against `write_fragments_internal`.
                let (fragments, _) = write_fragments_internal(
                    Some(dataset.as_ref()),
                    dataset.object_store.clone(),
                    &dataset.base,
                    write_schema,
                    stream,
                    Default::default(),
                    None,
                )
                .await?;

                new_fragments.lock().unwrap().extend(fragments);
                Ok(reservation_size)
            }

            let mut memory_size = batches
                .iter()
                .map(|batch| batch.get_array_memory_size())
                .sum();

            // Wait until both a task slot and enough memory are available, draining finished
            // tasks (and releasing their reservations) in the meantime. If the memory cannot be
            // reserved even with no tasks in flight, proceed without a reservation
            // (memory_size = 0) rather than waiting forever.
            loop {
                let have_additional_cpus = tasks.len() < task_limit;
                if have_additional_cpus {
                    if reservation.try_grow(memory_size).is_ok() {
                        break;
                    } else if tasks.is_empty() {
                        memory_size = 0;
                        break;
                    }
                }

                if let Some(res) = tasks.join_next().await {
                    let size = res??;
                    reservation.shrink(size);
                }
            }

            match frag_id.first() {
                Some(ScalarValue::UInt64(Some(frag_id))) => {
                    // Rows belonging to an existing fragment: rewrite its columns.
                    let frag_id = *frag_id;
                    let fragment = dataset.get_fragment(frag_id as usize).ok_or_else(|| {
                        error!(
                            fragment_id = frag_id,
                            dataset_uri = %dataset.uri(),
                            manifest_version = dataset.manifest().version,
                            manifest_path = %dataset.manifest_location().path,
                            branch = ?dataset.manifest().branch,
                            "Non-existent fragment id returned from merge result",
                        );
                        Error::internal(format!(
                            "Got non-existent fragment id from merge result: {} (uri={}, version={}, manifest={}, branch={})",
                            frag_id,
                            dataset.uri(),
                            dataset.manifest().version,
                            dataset.manifest_location().path,
                            dataset.manifest().branch.as_deref().unwrap_or("main"),
                        ))
                    })?;
                    let metadata = fragment.metadata.clone();

                    let fut = handle_fragment(
                        dataset.clone(),
                        fragment,
                        metadata,
                        batches,
                        updated_fragments.clone(),
                        memory_size,
                        current_version,
                    );
                    tasks.spawn(fut);
                }
                Some(ScalarValue::Null | ScalarValue::UInt64(None)) => {
                    // A null row address means the row has no target fragment; write it as new
                    // data.
                    let fut = handle_new_fragments(
                        dataset.clone(),
                        batches,
                        new_fragments.clone(),
                        memory_size,
                    );
                    tasks.spawn(fut);
                }
                _ => {
                    return Err(Error::internal(format!(
                        "Got non-fragment id from merge result: {:?}",
                        frag_id
                    )));
                }
            };
        }

        // Wait for the remaining tasks, releasing their reservations.
        while let Some(res) = tasks.join_next().await {
            let size = res??;
            reservation.shrink(size);
        }
        // All tasks have finished, so these are the only remaining references.
        let mut updated_fragments = Arc::try_unwrap(updated_fragments)
            .unwrap()
            .into_inner()
            .unwrap();

        let mut all_fields_updated = HashSet::new();

        for fragment in &mut updated_fragments {
            // The newest data file holds the rewritten columns. In every older data file, mark
            // those field ids as -2 so the stale copies are no longer read as the field.
            // NOTE(review): -2 appears to act as a field-id tombstone — confirm against the
            // Lance table format.
            let updated_fields = fragment.files.last().unwrap().fields.clone();
            all_fields_updated.extend(updated_fields.iter().map(|&f| f as u32));
            for data_file in &mut fragment.files.iter_mut().rev().skip(1) {
                let new_fields: Arc<[i32]> = data_file
                    .fields
                    .iter()
                    .map(|field| {
                        if updated_fields.contains(field) {
                            -2
                        } else {
                            *field
                        }
                    })
                    .collect::<Vec<_>>()
                    .into();
                data_file.fields = new_fields;
            }
        }

        let new_fragments = Arc::try_unwrap(new_fragments)
            .unwrap()
            .into_inner()
            .unwrap();

        Ok((
            updated_fragments,
            new_fragments,
            all_fields_updated.into_iter().collect(),
        ))
    }
1321
1322 pub async fn execute(
1327 self,
1328 source: SendableRecordBatchStream,
1329 ) -> Result<(Arc<Dataset>, MergeStats)> {
1330 let source_iter = super::new_source_iter(source, self.params.conflict_retries > 0).await?;
1331 let dataset = self.dataset.clone();
1332 let config = RetryConfig {
1333 max_retries: self.params.conflict_retries,
1334 retry_timeout: self.params.retry_timeout,
1335 };
1336
1337 let wrapper = MergeInsertJobWithIterator {
1338 job: self,
1339 source_iter: Arc::new(Mutex::new(source_iter)),
1340 attempt_count: Arc::new(AtomicU32::new(0)),
1341 };
1342
1343 Box::pin(execute_with_retry(wrapper, dataset, config)).await
1344 }
1345
1346 pub async fn execute_uncommitted(
1350 self,
1351 source: impl StreamingWriteSource,
1352 ) -> Result<UncommittedMergeInsert> {
1353 let stream = source.into_stream();
1354 self.execute_uncommitted_impl(stream).await
1355 }
1356
1357 fn create_plan_join_type(&self) -> JoinType {
1358 let keep_unmatched_source_rows = self.params.insert_not_matched;
1359 let keep_unmatched_target_rows = !matches!(
1360 self.params.delete_not_matched_by_source,
1361 WhenNotMatchedBySource::Keep
1362 );
1363
1364 match (keep_unmatched_target_rows, keep_unmatched_source_rows) {
1365 (false, false) => JoinType::Inner,
1366 (false, true) => JoinType::Right,
1367 (true, false) => JoinType::Left,
1368 (true, true) => JoinType::Full,
1369 }
1370 }
1371
    /// Builds the physical plan for a merge insert of `source` into the target dataset.
    ///
    /// The logical plan is built as follows:
    /// 1. Join a scan of the target (aliased `target`) with the source (aliased `source`) on the
    ///    join keys. The join type comes from `create_plan_join_type`.
    /// 2. Add the per-row merge action column.
    /// 3. Add a column for every dataset field the source lacks, copied from the target.
    /// 4. Wrap the result in a `MergeInsertWriteNode`.
    ///
    /// The plan is then optimized and planned with the `MergeInsertPlanner` extension planner.
    async fn create_plan(
        self,
        source: SendableRecordBatchStream,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let session_config = SessionConfig::default();
        let session_ctx = SessionContext::new_with_config(session_config);
        let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?;
        // Quote the key names so DataFusion treats each as a single, case-preserved identifier.
        let on_cols = self
            .params
            .on
            .iter()
            .map(|name| format!("\"{}\"", name))
            .collect::<Vec<_>>();
        let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();
        let source_df = session_ctx.read_one_shot(source)?;
        // Remember which columns the source provides, before the sentinel is added.
        let source_field_names: std::collections::HashSet<String> = source_df
            .schema()
            .fields()
            .iter()
            .map(|f| f.name().clone())
            .collect();
        let source_df = source_df
            .with_column(MERGE_SOURCE_SENTINEL, logical_expr::lit(true))
            .map_err(crate::Error::from)?;
        let source_df_aliased = source_df.alias("source")?;
        let scan_aliased = scan.alias("target")?;
        let join_type = self.create_plan_join_type();
        let dataset_schema: Schema = self.dataset.schema().into();
        let mut df = scan_aliased
            .join(
                source_df_aliased,
                join_type,
                &on_cols_refs,
                &on_cols_refs,
                None,
            )?
            .with_column(
                MERGE_ACTION_COLUMN,
                merge_insert_action(&self.params, Some(&dataset_schema))?,
            )?;

        // For dataset columns the source does not provide, expose the target's value under the
        // plain column name.
        for field in dataset_schema.fields() {
            if !source_field_names.contains(field.name()) {
                df = df.with_column(
                    field.name(),
                    logical_expr::col(format!("target.\"{}\"", field.name())),
                )?;
            }
        }

        let (session_state, logical_plan) = df.into_parts();

        let write_node = logical_plan::MergeInsertWriteNode::new(
            logical_plan,
            self.dataset.clone(),
            self.params.clone(),
        );
        let logical_plan = LogicalPlan::Extension(Extension {
            node: Arc::new(write_node),
        });

        let logical_plan = session_state.optimize(&logical_plan)?;

        let planner =
            DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(MergeInsertPlanner {})]);
        let physical_plan = planner
            .create_physical_plan(&logical_plan, &session_state)
            .await?;

        Ok(physical_plan)
    }
1468
1469 async fn execute_uncommitted_v2(
1470 self,
1471 source: SendableRecordBatchStream,
1472 ) -> Result<(
1473 Transaction,
1474 MergeStats,
1475 Option<RowAddrTreeMap>,
1476 Option<KeyExistenceFilter>,
1477 )> {
1478 let plan = self.create_plan(source).await?;
1479
1480 let partition_count = match plan.properties().output_partitioning() {
1483 datafusion_physical_expr::Partitioning::RoundRobinBatch(n) => *n,
1484 datafusion_physical_expr::Partitioning::Hash(_, n) => *n,
1485 datafusion_physical_expr::Partitioning::UnknownPartitioning(n) => *n,
1486 };
1487
1488 if partition_count != 1 {
1489 return Err(Error::invalid_input(format!(
1490 "Expected exactly 1 partition, got {}",
1491 partition_count
1492 )));
1493 }
1494
1495 let task_context = Arc::new(datafusion::execution::TaskContext::default());
1497 let mut stream = plan.execute(0, task_context)?;
1498
1499 if let Some(batch) = stream.next().await {
1501 let batch = batch?;
1502 if batch.num_rows() > 0 {
1503 return Err(Error::invalid_input(format!(
1504 "Expected no output from write operation, got {} rows",
1505 batch.num_rows()
1506 )));
1507 }
1508 }
1509
1510 let (stats, transaction, affected_rows, inserted_rows_filter) = if let Some(full_exec) =
1512 plan.as_any()
1513 .downcast_ref::<exec::FullSchemaMergeInsertExec>()
1514 {
1515 let stats = full_exec.merge_stats().ok_or_else(|| {
1516 Error::internal("Merge stats not available - execution may not have completed")
1517 })?;
1518 let transaction = full_exec.transaction().ok_or_else(|| {
1519 Error::internal("Transaction not available - execution may not have completed")
1520 })?;
1521 let affected_rows = full_exec.affected_rows().map(RowAddrTreeMap::from);
1522 let inserted_rows_filter = full_exec.inserted_rows_filter();
1523 (stats, transaction, affected_rows, inserted_rows_filter)
1524 } else if let Some(delete_exec) = plan
1525 .as_any()
1526 .downcast_ref::<exec::DeleteOnlyMergeInsertExec>()
1527 {
1528 let stats = delete_exec.merge_stats().ok_or_else(|| {
1529 Error::internal("Merge stats not available - execution may not have completed")
1530 })?;
1531 let transaction = delete_exec.transaction().ok_or_else(|| {
1532 Error::internal("Transaction not available - execution may not have completed")
1533 })?;
1534 let affected_rows = delete_exec.affected_rows().map(RowAddrTreeMap::from);
1535 (stats, transaction, affected_rows, None)
1536 } else {
1537 return Err(Error::internal(
1538 "Expected FullSchemaMergeInsertExec or DeleteOnlyMergeInsertExec",
1539 ));
1540 };
1541
1542 Ok((transaction, stats, affected_rows, inserted_rows_filter))
1543 }
1544
1545 async fn can_use_create_plan(&self, source_schema: &Schema) -> Result<bool> {
1560 let lance_schema = lance_core::datatypes::Schema::try_from(source_schema)?;
1562 let full_schema = self.dataset.schema();
1563 let is_full_schema = full_schema.compare_with_options(
1564 &lance_schema,
1565 &SchemaCompareOptions {
1566 compare_metadata: false,
1567 compare_nullability: NullabilityComparison::Ignore,
1569 ignore_field_order: true,
1571 ..Default::default()
1572 },
1573 );
1574
1575 let is_subset_schema = !is_full_schema
1579 && lance_schema.fields.iter().all(|sf| {
1580 full_schema
1581 .field(&sf.name)
1582 .map(|tf| tf.data_type() == sf.data_type())
1583 .unwrap_or(false)
1584 });
1585
1586 if is_subset_schema && self.params.insert_not_matched {
1591 let non_nullable_missing: Vec<&str> = full_schema
1592 .fields
1593 .iter()
1594 .filter(|tf| lance_schema.field(&tf.name).is_none() && !tf.nullable)
1595 .map(|tf| tf.name.as_str())
1596 .collect();
1597 if !non_nullable_missing.is_empty() {
1598 return Err(Error::invalid_input(format!(
1599 "Cannot insert rows with a partial-schema source: target column(s) \
1600 {:?} are non-nullable and not provided by the source. Either add \
1601 them to the source or set when_not_matched to DoNothing.",
1602 non_nullable_missing
1603 )));
1604 }
1605 }
1606
1607 let would_use_scalar_index = if self.params.use_index
1608 && matches!(
1609 self.params.delete_not_matched_by_source,
1610 WhenNotMatchedBySource::Keep
1611 ) {
1612 self.join_key_as_scalar_index().await?.is_some()
1613 } else {
1614 false
1615 };
1616
1617 let no_upsert = matches!(
1620 self.params.when_matched,
1621 WhenMatched::Delete | WhenMatched::DoNothing
1622 ) && !self.params.insert_not_matched;
1623
1624 let source_has_key_columns = self.params.on.iter().all(|key| {
1626 source_schema
1627 .fields()
1628 .iter()
1629 .any(|f| f.name() == key.as_str())
1630 });
1631 let schema_ok = is_full_schema || is_subset_schema || (no_upsert && source_has_key_columns);
1632
1633 Ok(matches!(
1634 self.params.when_matched,
1635 WhenMatched::UpdateAll
1636 | WhenMatched::UpdateIf(_)
1637 | WhenMatched::Fail
1638 | WhenMatched::Delete
1639 | WhenMatched::DoNothing
1640 ) && !would_use_scalar_index
1641 && schema_ok
1642 && matches!(
1643 self.params.delete_not_matched_by_source,
1644 WhenNotMatchedBySource::Keep
1645 | WhenNotMatchedBySource::Delete
1646 | WhenNotMatchedBySource::DeleteIf(_)
1647 ))
1648 }
1649
/// Run the merge insert and build an uncommitted transaction.
///
/// If `can_use_create_plan` accepts the configuration, the DataFusion plan path
/// (`execute_uncommitted_v2`) is used. Otherwise the legacy path runs:
/// 1. join source and target with `create_joined_stream`;
/// 2. pass each joined batch through a [`Merger`];
/// 3. then, depending on the schemas:
///    * different from the dataset schema: rewrite columns of existing fragments
///      (`RewriteColumns`);
///    * same as the dataset schema: write the merged rows as new fragments and
///      delete the replaced/removed target rows (`RewriteRows`).
///
/// # Errors
/// Returns a "not supported" error when the source schema differs from the target
/// and `when_not_matched_by_source` is not `Keep`. Otherwise propagates planning,
/// execution and I/O errors.
async fn execute_uncommitted_impl(
    self,
    source: SendableRecordBatchStream,
) -> Result<UncommittedMergeInsert> {
    let can_use_fast_path = self.can_use_create_plan(source.schema().as_ref()).await?;

    if can_use_fast_path {
        let (transaction, stats, affected_rows, inserted_rows_filter) =
            self.execute_uncommitted_v2(source).await?;
        return Ok(UncommittedMergeInsert {
            transaction,
            affected_rows,
            stats,
            inserted_rows_filter,
        });
    }

    // Legacy path. Unlike `can_use_create_plan`, this comparison leaves
    // `ignore_field_order` at its default.
    let source_schema = source.schema();
    let lance_schema = lance_core::datatypes::Schema::try_from(source_schema.as_ref())?;
    let full_schema = self.dataset.schema();
    let is_full_schema = full_schema.compare_with_options(
        &lance_schema,
        &SchemaCompareOptions {
            compare_metadata: false,
            compare_nullability: NullabilityComparison::Ignore,
            ..Default::default()
        },
    );
    let joined = self.create_joined_stream(source).await?;
    // For a partial schema the merger also forwards a row-address column
    // (`with_row_addr = !is_full_schema`).
    let merger = Merger::try_new(
        self.params.clone(),
        source_schema,
        !is_full_schema,
        self.dataset.manifest.uses_stable_row_ids(),
    )?;
    // Keep handles to the merger's shared state. They are read back after the
    // merged stream has been consumed.
    let merge_statistics = merger.merge_stats.clone();
    let deleted_rows = merger.deleted_rows.clone();
    let updating_row_ids = merger.updating_row_ids.clone();
    let merger_schema = merger.output_schema().clone();
    // Each joined batch expands into zero or more merged output batches.
    let stream = joined
        .and_then(move |batch| merger.clone().execute_batch(batch))
        .try_flatten();
    let stream = RecordBatchStreamAdapter::new(merger_schema, stream);

    let (operation, affected_rows) = if !is_full_schema {
        if !matches!(
            self.params.delete_not_matched_by_source,
            WhenNotMatchedBySource::Keep
        ) {
            return Err(Error::not_supported_source("Deleting rows from the target table when there is no match in the source table is not supported when the source data has a different schema than the target data".into()));
        }

        // Partial schema: rewrite the affected columns of existing fragments,
        // targeting the next dataset version.
        let (updated_fragments, new_fragments, fields_modified) = Self::update_fragments(
            self.dataset.clone(),
            Box::pin(stream),
            self.dataset.manifest.version + 1,
        )
        .await?;

        let operation = Operation::Update {
            removed_fragment_ids: Vec::new(),
            updated_fragments,
            new_fragments,
            fields_modified,
            merged_generations: self.params.merged_generations.clone(),
            fields_for_preserving_frag_bitmap: vec![],
            update_mode: Some(RewriteColumns),
            inserted_rows_filter: None,
            updated_fragment_offsets: None,
        };
        (operation, None)
    } else {
        // Full schema: write every merged row (updated and inserted) to new fragments.
        let (mut new_fragments, _) = write_fragments_internal(
            Some(&self.dataset),
            self.dataset.object_store.clone(),
            &self.dataset.base,
            self.dataset.schema().clone(),
            Box::pin(stream),
            WriteParams::default(),
            None,
        )
        .await?;

        // With stable row ids, the captured ids of updated rows appear to be
        // re-chunked by fragment row count and stored inline on each new fragment.
        if let Some(row_id_sequence) = updating_row_ids.lock().unwrap().row_id_sequence() {
            let fragment_sizes = new_fragments
                .iter()
                .map(|f| f.physical_rows.unwrap() as u64);

            let sequences = lance_table::rowids::rechunk_sequences(
                [row_id_sequence.clone()],
                fragment_sizes,
                true,
            )
            .map_err(|e| {
                Error::internal(format!(
                    "Captured row ids not equal to number of rows written: {}",
                    e
                ))
            })?;

            for (fragment, sequence) in new_fragments.iter_mut().zip(sequences) {
                let serialized = lance_table::rowids::write_row_ids(&sequence);
                fragment.row_id_meta = Some(RowIdMeta::Inline(serialized));
            }
        }

        // `Arc::into_inner` only succeeds once every `Merger` clone is gone, i.e. the
        // merged stream has been dropped by the writer. Otherwise this unwrap panics.
        let removed_row_ids = Arc::into_inner(deleted_rows).unwrap().into_inner().unwrap();

        // Map row ids to row addresses through the row-id index when one exists;
        // ids not found in the index are dropped. Without an index the ids are used
        // as addresses directly (presumably because they coincide in that mode).
        let removed_row_addr_vec =
            if let Some(row_id_index) = get_row_id_index(&self.dataset).await? {
                let addresses: Vec<u64> = removed_row_ids
                    .iter()
                    .filter_map(|id| row_id_index.get(*id).map(|address| address.into()))
                    .collect::<Vec<_>>();
                addresses
            } else {
                removed_row_ids
            };

        let removed_row_addrs = RoaringTreemap::from_iter(removed_row_addr_vec.into_iter());

        let (old_fragments, removed_fragment_ids) =
            Self::apply_deletions(&self.dataset, &removed_row_addrs).await?;

        let operation = Operation::Update {
            removed_fragment_ids,
            updated_fragments: old_fragments,
            new_fragments,
            fields_modified: vec![],
            merged_generations: self.params.merged_generations.clone(),
            // Every dataset field id is listed here.
            fields_for_preserving_frag_bitmap: full_schema
                .fields
                .iter()
                .map(|f| f.id as u32)
                .collect(),
            update_mode: Some(RewriteRows),
            inserted_rows_filter: None,
            updated_fragment_offsets: None,
        };

        let affected_rows = Some(RowAddrTreeMap::from(removed_row_addrs));
        (operation, affected_rows)
    };

    // Same requirement as above: all merger clones must have been dropped.
    let stats = Arc::into_inner(merge_statistics)
        .unwrap()
        .into_inner()
        .unwrap();

    let transaction = Transaction::new(self.dataset.manifest.version, operation, None);

    Ok(UncommittedMergeInsert {
        transaction,
        affected_rows,
        stats,
        inserted_rows_filter: None,
    })
}
1818
1819 async fn apply_deletions(
1821 dataset: &Dataset,
1822 removed_row_ids: &RoaringTreemap,
1823 ) -> Result<(Vec<Fragment>, Vec<u64>)> {
1824 let bitmaps = Arc::new(removed_row_ids.bitmaps().collect::<BTreeMap<_, _>>());
1825
1826 enum FragmentChange {
1827 Unchanged,
1828 Modified(Box<Fragment>),
1829 Removed(u64),
1830 }
1831
1832 let mut updated_fragments = Vec::new();
1833 let mut removed_fragments = Vec::new();
1834
1835 let mut stream = futures::stream::iter(dataset.get_fragments())
1836 .map(move |fragment| {
1837 let bitmaps_ref = bitmaps.clone();
1838 async move {
1839 let fragment_id = fragment.id();
1840 if let Some(bitmap) = bitmaps_ref.get(&(fragment_id as u32)) {
1841 match fragment.extend_deletions(*bitmap).await {
1842 Ok(Some(new_fragment)) => {
1843 Ok(FragmentChange::Modified(Box::new(new_fragment.metadata)))
1844 }
1845 Ok(None) => Ok(FragmentChange::Removed(fragment_id as u64)),
1846 Err(e) => Err(e),
1847 }
1848 } else {
1849 Ok(FragmentChange::Unchanged)
1850 }
1851 }
1852 })
1853 .buffer_unordered(dataset.object_store.io_parallelism());
1854
1855 while let Some(res) = stream.next().await.transpose()? {
1856 match res {
1857 FragmentChange::Unchanged => {}
1858 FragmentChange::Modified(fragment) => updated_fragments.push(*fragment),
1859 FragmentChange::Removed(fragment_id) => removed_fragments.push(fragment_id),
1860 }
1861 }
1862
1863 Ok((updated_fragments, removed_fragments))
1864 }
1865
1866 pub async fn explain_plan(&self, schema: Option<&Schema>, verbose: bool) -> Result<String> {
1882 let schema = match schema {
1884 Some(s) => s.clone(),
1885 None => arrow_schema::Schema::from(self.dataset.schema()),
1886 };
1887
1888 if !self.can_use_create_plan(&schema).await? {
1890 return Err(Error::not_supported_source("This merge insert configuration does not support explain_plan. Only full-schema merge insert operations without a scalar-index execution path are currently supported.".into()));
1891 }
1892
1893 let empty_batch = RecordBatch::new_empty(Arc::new(schema.clone()));
1895 let stream = RecordBatchStreamAdapter::new(
1896 Arc::new(schema.clone()),
1897 futures::stream::once(async { Ok(empty_batch) }).boxed(),
1898 );
1899
1900 let cloned_job = self.clone();
1902 let plan = cloned_job.create_plan(Box::pin(stream)).await?;
1903 let display = DisplayableExecutionPlan::new(plan.as_ref());
1904
1905 Ok(format!("{}", display.indent(verbose)))
1906 }
1907
1908 pub async fn analyze_plan(&self, source: SendableRecordBatchStream) -> Result<String> {
1928 if !self.can_use_create_plan(source.schema().as_ref()).await? {
1930 return Err(Error::not_supported_source("This merge insert configuration does not support analyze_plan. Only full-schema merge insert operations without a scalar-index execution path are currently supported.".into()));
1931 }
1932
1933 let cloned_job = self.clone();
1935 let plan = cloned_job.create_plan(source).await?;
1936
1937 let options = LanceExecutionOptions::default();
1939 let full_analysis = analyze_plan(plan, options).await?;
1940
1941 let lines: Vec<&str> = full_analysis.lines().collect();
1943 let filtered_lines: Vec<&str> = lines
1944 .into_iter()
1945 .filter(|line| {
1946 !line.trim_start().starts_with("AnalyzeExec")
1947 && !line.trim_start().starts_with("TracedExec")
1948 })
1949 .collect();
1950
1951 Ok(filtered_lines.join("\n"))
1952 }
1953}
1954
/// Counters describing the outcome of a merge insert.
#[derive(Debug, Default, Clone)]
pub struct MergeStats {
    /// Number of source rows inserted because they matched no target row.
    pub num_inserted_rows: u64,
    /// Number of target rows updated by a matching source row. Counted after
    /// `UpdateIf` filtering and after source duplicates are skipped.
    pub num_updated_rows: u64,
    /// Number of target rows deleted by the merge (e.g. by the
    /// `when_not_matched_by_source` delete rules).
    pub num_deleted_rows: u64,
    /// Number of execution attempts made by the retrying `execute` path. It is
    /// incremented once per attempt; uncommitted execution does not set it here.
    pub num_attempts: u32,
    /// Bytes written by the operation. NOTE(review): not populated in this part
    /// of the file — presumably filled in by the plan-based write path.
    pub bytes_written: u64,
    /// Number of data files written. NOTE(review): not populated in this part
    /// of the file — presumably filled in by the plan-based write path.
    pub num_files_written: u64,
    /// Number of matched source rows skipped because an earlier source row had
    /// already matched the same target row (`SourceDedupeBehavior::FirstSeen`).
    pub num_skipped_duplicates: u64,
}
1977
/// Result of a merge insert that has been executed but not yet committed.
pub struct UncommittedMergeInsert {
    /// The transaction to commit.
    pub transaction: Transaction,
    /// Row addresses touched by the merge, passed to the commit when present.
    pub affected_rows: Option<RowAddrTreeMap>,
    /// Statistics gathered while executing the merge.
    pub stats: MergeStats,
    /// Filter over the keys of inserted rows. Only produced by the plan-based
    /// full-schema path.
    pub inserted_rows_filter: Option<KeyExistenceFilter>,
}
1984
/// Adapter that lets a [`MergeInsertJob`] run under `execute_with_retry`.
#[derive(Clone)]
struct MergeInsertJobWithIterator {
    /// The job to execute; its dataset is replaced between attempts.
    job: MergeInsertJob,
    /// Source streams, one taken per attempt.
    source_iter: Arc<Mutex<Box<dyn Iterator<Item = SendableRecordBatchStream> + Send + 'static>>>,
    /// Number of attempts started so far; reported as `MergeStats::num_attempts`.
    attempt_count: Arc<AtomicU32>,
}
1992
1993impl RetryExecutor for MergeInsertJobWithIterator {
1994 type Data = UncommittedMergeInsert;
1995 type Result = (Arc<Dataset>, MergeStats);
1996
1997 async fn execute_impl(&self) -> Result<Self::Data> {
1998 self.attempt_count.fetch_add(1, Ordering::SeqCst);
2000
2001 let stream = self.source_iter.lock().unwrap().next().unwrap();
2004 self.job.clone().execute_uncommitted_impl(stream).await
2005 }
2006
2007 async fn commit(&self, dataset: Arc<Dataset>, mut data: Self::Data) -> Result<Self::Result> {
2008 data.stats.num_attempts = self.attempt_count.load(Ordering::SeqCst);
2010
2011 let mut commit_builder =
2012 CommitBuilder::new(dataset).with_skip_auto_cleanup(self.job.params.skip_auto_cleanup);
2013 if let Some(commit_retries) = self.job.params.commit_retries {
2014 commit_builder = commit_builder.with_max_retries(commit_retries);
2015 }
2016 if let Some(affected_rows) = data.affected_rows {
2017 commit_builder = commit_builder.with_affected_rows(affected_rows);
2018 }
2019 let new_dataset = commit_builder.execute(data.transaction).await?;
2020
2021 Ok((Arc::new(new_dataset), data.stats))
2022 }
2023
2024 fn update_dataset(&mut self, dataset: Arc<Dataset>) {
2025 self.job.dataset = dataset;
2026 }
2027}
2028
/// Per-batch merge logic for the legacy (non-plan) merge-insert path.
///
/// A clone is made for every joined batch. All mutable state sits behind
/// `Arc<Mutex<_>>`, so the clones share it.
#[derive(Debug, Clone)]
struct Merger {
    /// Target row ids to delete: rows replaced by an update plus rows removed
    /// by `when_not_matched_by_source`.
    deleted_rows: Arc<Mutex<Vec<u64>>>,
    /// Row ids of updated target rows; only captured when stable row ids are enabled.
    updating_row_ids: Arc<Mutex<CapturedRowIds>>,
    /// Compiled `DeleteIf` condition, if configured.
    delete_expr: Option<Arc<dyn PhysicalExpr>>,
    /// Statistics accumulated across all batches.
    merge_stats: Arc<Mutex<MergeStats>>,
    /// Compiled `UpdateIf` condition, if configured.
    match_filter_expr: Option<Arc<dyn PhysicalExpr>>,
    /// The merge-insert configuration.
    params: MergeInsertParams,
    /// Schema of the source data.
    schema: Arc<Schema>,
    /// Whether joined batches carry a trailing row-address column to forward.
    with_row_addr: bool,
    /// Schema of emitted batches: `schema`, plus the row-address field when
    /// `with_row_addr` is set.
    output_schema: Arc<Schema>,
    /// Whether the dataset uses stable row ids.
    enable_stable_row_ids: bool,
    /// Target row ids already matched by an earlier source row; used to detect
    /// duplicate source keys.
    processed_row_ids: Arc<Mutex<HashSet<u64>>>,
}
2058
2059impl Merger {
2060 fn try_new(
2062 params: MergeInsertParams,
2063 schema: Arc<Schema>,
2064 with_row_addr: bool,
2065 enable_stable_row_ids: bool,
2066 ) -> Result<Self> {
2067 let delete_expr = if let WhenNotMatchedBySource::DeleteIf(expr) =
2068 ¶ms.delete_not_matched_by_source
2069 {
2070 let planner = Planner::new(schema.clone());
2071 let expr = planner.optimize_expr(expr.clone())?;
2072 let physical_expr = planner.create_physical_expr(&expr)?;
2073 let data_type = physical_expr.data_type(&schema)?;
2074 if data_type != DataType::Boolean {
2075 return Err(Error::invalid_input(format!(
2076 "Merge insert conditions must be expressions that return a boolean value, received expression ({}) which has data type {}",
2077 expr, data_type
2078 )));
2079 }
2080 Some(physical_expr)
2081 } else {
2082 None
2083 };
2084 let match_filter_expr = if let WhenMatched::UpdateIf(expr_str) = ¶ms.when_matched {
2085 let combined_schema = Arc::new(combined_schema(&schema));
2086 let planner = Planner::new(combined_schema.clone());
2087 let expr = planner.parse_filter(expr_str)?;
2088 let expr = planner.optimize_expr(expr)?;
2089 let match_expr = planner.create_physical_expr(&expr)?;
2090 let data_type = match_expr.data_type(combined_schema.as_ref())?;
2091 if data_type != DataType::Boolean {
2092 return Err(Error::invalid_input(format!(
2093 "Merge insert conditions must be expressions that return a boolean value, received a 'when matched update if' expression ({}) which has data type {}",
2094 expr, data_type
2095 )));
2096 }
2097 Some(match_expr)
2098 } else {
2099 None
2100 };
2101 let output_schema = if with_row_addr {
2102 Arc::new(schema.try_with_column(ROW_ADDR_FIELD.clone())?)
2103 } else {
2104 schema.clone()
2105 };
2106
2107 Ok(Self {
2108 deleted_rows: Arc::new(Mutex::new(Vec::new())),
2109 updating_row_ids: Arc::new(Mutex::new(CapturedRowIds::new(enable_stable_row_ids))),
2110 delete_expr,
2111 merge_stats: Arc::new(Mutex::new(MergeStats::default())),
2112 match_filter_expr,
2113 params,
2114 schema,
2115 with_row_addr,
2116 output_schema,
2117 enable_stable_row_ids,
2118 processed_row_ids: Arc::new(Mutex::new(HashSet::new())),
2119 })
2120 }
2121
2122 fn output_schema(&self) -> &Arc<Schema> {
2123 &self.output_schema
2124 }
2125
2126 fn not_all_null(
2130 batch: &RecordBatch,
2131 col_offset: usize,
2132 num_cols: usize,
2133 ) -> Result<BooleanArray> {
2134 debug_assert_ne!(num_cols, 0);
2136 let mut at_least_one_valid = arrow::compute::is_not_null(batch.column(col_offset))?;
2137 for idx in col_offset + 1..col_offset + num_cols {
2138 let is_valid = arrow::compute::is_not_null(batch.column(idx))?;
2139 at_least_one_valid = arrow::compute::or(&at_least_one_valid, &is_valid)?;
2140 }
2141 Ok(at_least_one_valid)
2142 }
2143
2144 fn extract_selections(
2160 &self,
2161 combined_batch: &RecordBatch,
2162 right_offset: usize,
2163 num_keys: usize,
2164 ) -> Result<(BooleanArray, BooleanArray, BooleanArray)> {
2165 let in_left = Self::not_all_null(combined_batch, 0, num_keys)?;
2166 let in_right = Self::not_all_null(combined_batch, right_offset, num_keys)?;
2167 let in_both = arrow::compute::and(&in_left, &in_right)?;
2168 let left_only = arrow::compute::and(&in_left, &arrow::compute::not(&in_right)?)?;
2169 let right_only = arrow::compute::and(&arrow::compute::not(&in_left)?, &in_right)?;
2170 Ok((left_only, in_both, right_only))
2171 }
2172
2173 async fn execute_batch(
2180 self,
2181 batch: RecordBatch,
2182 ) -> datafusion::common::Result<impl Stream<Item = datafusion::common::Result<RecordBatch>>>
2183 {
2184 let mut merge_statistics = self.merge_stats.lock().unwrap();
2185 let num_fields = batch.schema().fields.len();
2186 let (row_id_col, row_addr_col, right_offset) = if num_fields % 2 == 1 {
2190 assert!(!self.with_row_addr);
2192 (num_fields - 1, None, num_fields / 2)
2193 } else {
2194 assert!(self.with_row_addr);
2196 (num_fields - 2, Some(num_fields - 1), (num_fields - 2) / 2)
2197 };
2198
2199 let num_keys = self.params.on.len();
2200
2201 let left_cols = Vec::from_iter(0..right_offset);
2202 let right_cols_with_id = Vec::from_iter(right_offset..num_fields);
2203
2204 let mut batches = Vec::with_capacity(2);
2205 let (left_only, in_both, right_only) =
2206 self.extract_selections(&batch, right_offset, num_keys)?;
2207
2208 let mut deleted_row_ids = self.deleted_rows.lock().unwrap();
2211
2212 if self.params.when_matched != WhenMatched::DoNothing {
2213 let mut matched = arrow::compute::filter_record_batch(&batch, &in_both)?;
2214
2215 if let Some(match_filter) = self.match_filter_expr {
2216 let unzipped = unzip_batch(&matched, &self.schema);
2217 let filtered = match_filter.evaluate(&unzipped)?;
2218 match filtered {
2219 ColumnarValue::Array(mask) => {
2220 matched = arrow::compute::filter_record_batch(&matched, mask.as_boolean())?;
2222 }
2223 ColumnarValue::Scalar(scalar) => {
2224 if let ScalarValue::Boolean(Some(true)) = scalar {
2225 } else {
2227 matched = RecordBatch::new_empty(matched.schema());
2229 }
2230 }
2231 }
2232 }
2233
2234 merge_statistics.num_updated_rows += matched.num_rows() as u64;
2235
2236 if matched.num_rows() > 0 {
2239 let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
2240
2241 let mut processed_row_ids = self.processed_row_ids.lock().unwrap();
2242 let mut keep_indices: Vec<u32> = Vec::with_capacity(matched.num_rows());
2243 for (row_idx, &row_id) in row_ids.values().iter().enumerate() {
2244 if processed_row_ids.insert(row_id) {
2245 keep_indices.push(row_idx as u32);
2246 } else {
2247 match self.params.source_dedupe_behavior {
2248 SourceDedupeBehavior::Fail => {
2249 return Err(create_duplicate_row_error(
2250 &matched,
2251 row_idx,
2252 &self.params.on,
2253 ));
2254 }
2255 SourceDedupeBehavior::FirstSeen => {
2256 }
2258 }
2259 }
2260 }
2261 drop(processed_row_ids);
2262
2263 let num_skipped = matched.num_rows() - keep_indices.len();
2265 if num_skipped > 0 {
2266 merge_statistics.num_skipped_duplicates += num_skipped as u64;
2267 merge_statistics.num_updated_rows -= num_skipped as u64;
2268
2269 let indices = UInt32Array::from(keep_indices);
2270 matched = take_record_batch(&matched, &indices)?;
2271 }
2272
2273 if matched.num_rows() > 0 {
2275 let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
2277 deleted_row_ids.extend(row_ids.values());
2278 if self.enable_stable_row_ids {
2279 self.updating_row_ids
2280 .lock()
2281 .unwrap()
2282 .capture(row_ids.values())?;
2283 }
2284
2285 let projection = if let Some(row_addr_col) = row_addr_col {
2286 let mut cols = Vec::from_iter(left_cols.iter().cloned());
2287 cols.push(row_addr_col);
2288 cols
2289 } else {
2290 #[allow(clippy::redundant_clone)]
2291 left_cols.clone()
2292 };
2293 let matched = matched.project(&projection)?;
2294 let matched = RecordBatch::try_new(
2300 self.output_schema.clone(),
2301 Vec::from_iter(matched.columns().iter().cloned()),
2302 )?;
2303 batches.push(Ok(matched));
2304 }
2305 }
2306 }
2307 if self.params.insert_not_matched {
2308 let not_matched = arrow::compute::filter_record_batch(&batch, &left_only)?;
2309 let left_cols_with_id = left_cols
2310 .into_iter()
2311 .chain(row_addr_col)
2312 .collect::<Vec<_>>();
2313 let not_matched = not_matched.project(&left_cols_with_id)?;
2314 let not_matched = RecordBatch::try_new(
2316 self.output_schema.clone(),
2317 Vec::from_iter(not_matched.columns().iter().cloned()),
2318 )?;
2319
2320 merge_statistics.num_inserted_rows += not_matched.num_rows() as u64;
2321 batches.push(Ok(not_matched));
2322 }
2323 match self.params.delete_not_matched_by_source {
2324 WhenNotMatchedBySource::Delete => {
2325 let unmatched = arrow::compute::filter(batch.column(row_id_col), &right_only)?;
2326 merge_statistics.num_deleted_rows += unmatched.len() as u64;
2327 let row_ids = unmatched.as_primitive::<UInt64Type>();
2328 deleted_row_ids.extend(row_ids.values());
2329 }
2330 WhenNotMatchedBySource::DeleteIf(_) => {
2331 let target_data = batch.project(&right_cols_with_id)?;
2332 let unmatched = arrow::compute::filter_record_batch(&target_data, &right_only)?;
2333 let row_id_col = unmatched.num_columns() - 1;
2334 let to_delete = self.delete_expr.unwrap().evaluate(&unmatched)?;
2335
2336 match to_delete {
2337 ColumnarValue::Array(mask) => {
2338 let row_ids = arrow::compute::filter(
2339 unmatched.column(row_id_col),
2340 mask.as_boolean(),
2341 )?;
2342 let row_ids = row_ids.as_primitive::<UInt64Type>();
2343 merge_statistics.num_deleted_rows += row_ids.len() as u64;
2344 deleted_row_ids.extend(row_ids.values());
2345 }
2346 ColumnarValue::Scalar(scalar) => {
2347 if let ScalarValue::Boolean(Some(true)) = scalar {
2348 let row_ids = unmatched.column(row_id_col).as_primitive::<UInt64Type>();
2349 merge_statistics.num_deleted_rows += row_ids.len() as u64;
2350 deleted_row_ids.extend(row_ids.values());
2351 }
2352 }
2353 }
2354 }
2355 WhenNotMatchedBySource::Keep => {}
2356 }
2357
2358 Ok(stream::iter(batches))
2359 }
2360}
2361
2362#[cfg(test)]
2363mod tests {
2364 use super::*;
2365 use crate::dataset::scanner::ColumnOrdering;
2366 use crate::dataset::write::merge_insert::inserted_rows::{
2367 KeyExistenceFilter, KeyExistenceFilterBuilder, extract_key_value_from_batch,
2368 };
2369 use crate::index::vector::VectorIndexParams;
2370 use crate::io::commit::read_transaction_file;
2371 use crate::{
2372 dataset::{InsertBuilder, ReadParams, WriteMode, WriteParams, builder::DatasetBuilder},
2373 session::Session,
2374 utils::test::{
2375 DatagenExt, FragmentCount, FragmentRowCount, ThrottledStoreWrapper,
2376 assert_plan_node_equals, assert_string_matches,
2377 },
2378 };
2379 use arrow_array::builder::{ListBuilder, StringBuilder};
2380 use arrow_array::types::Float32Type;
2381 use arrow_array::{
2382 Array, FixedSizeListArray, Float32Array, Float64Array, Int32Array, Int64Array, ListArray,
2383 RecordBatchIterator, RecordBatchReader, StringArray, StructArray, UInt32Array,
2384 types::{Int32Type, UInt32Type},
2385 };
2386 use arrow_array::{RecordBatch, record_batch};
2387 use arrow_buffer::{OffsetBuffer, ScalarBuffer};
2388 use arrow_schema::{DataType, Field, Schema};
2389 use arrow_select::concat::concat_batches;
2390 use datafusion::common::Column;
2391 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
2392 use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join_all};
2393 use lance_arrow::FixedSizeListArrayExt;
2394 use lance_core::utils::tempfile::TempStrDir;
2395 use lance_datafusion::{datagen::DatafusionDatagenExt, utils::reader_to_stream};
2396 use lance_datagen::{BatchCount, Dimension, RowCount, Seed, array};
2397 use lance_index::IndexType;
2398 use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams, ScalarIndexParams};
2399 use lance_io::object_store::ObjectStoreParams;
2400 use lance_linalg::distance::MetricType;
2401 use mock_instant::thread_local::MockClock;
2402 use object_store::throttle::ThrottleConfig;
2403 use roaring::RoaringBitmap;
2404 use std::collections::HashMap;
2405 use tokio::sync::{Barrier, Notify};
2406
2407 fn assert_send<T: Send>(t: T) -> T {
2409 t
2410 }
2411
2412 async fn check_then_refresh_dataset(
2413 new_data: RecordBatch,
2414 mut job: MergeInsertJob,
2415 keys_from_left: &[u32],
2416 keys_from_right: &[u32],
2417 stats: &[u64],
2418 ) -> Arc<Dataset> {
2419 let mut dataset = (*job.dataset).clone();
2420 dataset.restore().await.unwrap();
2421 job.dataset = Arc::new(dataset);
2422
2423 let schema = new_data.schema();
2424 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_data)], schema.clone()));
2425 let new_stream = reader_to_stream(new_reader);
2426
2427 let (merged_dataset, merge_stats) = job.execute(new_stream).boxed().await.unwrap();
2428
2429 let batches = merged_dataset
2430 .scan()
2431 .try_into_stream()
2432 .await
2433 .unwrap()
2434 .try_collect::<Vec<_>>()
2435 .await
2436 .unwrap();
2437
2438 let merged = concat_batches(&schema, &batches).unwrap();
2439
2440 let keyvals = merged
2441 .column(0)
2442 .as_primitive::<UInt32Type>()
2443 .values()
2444 .iter()
2445 .zip(
2446 merged
2447 .column(1)
2448 .as_primitive::<UInt32Type>()
2449 .values()
2450 .iter(),
2451 );
2452 let mut left_keys = keyvals
2453 .clone()
2454 .filter(|&(_, &val)| val == 1)
2455 .map(|(key, _)| key)
2456 .copied()
2457 .collect::<Vec<_>>();
2458 let mut right_keys = keyvals
2459 .clone()
2460 .filter(|&(_, &val)| val == 2)
2461 .map(|(key, _)| key)
2462 .copied()
2463 .collect::<Vec<_>>();
2464 left_keys.sort();
2465 right_keys.sort();
2466 assert_eq!(left_keys, keys_from_left);
2467 assert_eq!(right_keys, keys_from_right);
2468 assert_eq!(merge_stats.num_inserted_rows, stats[0]);
2469 assert_eq!(merge_stats.num_updated_rows, stats[1]);
2470 assert_eq!(merge_stats.num_deleted_rows, stats[2]);
2471
2472 merged_dataset
2473 }
2474
2475 fn create_test_schema() -> Arc<Schema> {
2476 Arc::new(Schema::new(vec![
2477 Field::new("key", DataType::UInt32, true),
2478 Field::new("value", DataType::UInt32, true),
2479 Field::new("filterme", DataType::Utf8, true),
2480 ]))
2481 }
2482
2483 fn create_new_batch(schema: Arc<Schema>) -> RecordBatch {
2484 RecordBatch::try_new(
2485 schema,
2486 vec![
2487 Arc::new(UInt32Array::from(vec![4, 5, 6, 7, 8, 9])),
2488 Arc::new(UInt32Array::from(vec![2, 2, 2, 2, 2, 2])),
2489 Arc::new(StringArray::from(vec!["A", "B", "C", "A", "B", "C"])),
2490 ],
2491 )
2492 .unwrap()
2493 }
2494
2495 async fn create_test_dataset(
2496 test_uri: &str,
2497 version: LanceFileVersion,
2498 enable_stable_row_ids: bool,
2499 ) -> Arc<Dataset> {
2500 let dataset = lance_datagen::gen_batch()
2501 .col("key", array::step_custom::<UInt32Type>(1, 1))
2502 .col("value", array::fill::<UInt32Type>(1u32))
2503 .col(
2504 "filterme",
2505 array::cycle_utf8_literals(&["A", "B", "A", "A", "B", "A"]),
2506 )
2507 .into_dataset_with_params(
2508 test_uri,
2509 FragmentCount(2),
2510 FragmentRowCount(3),
2511 Some(WriteParams {
2512 max_rows_per_file: 3,
2513 data_storage_version: Some(version),
2514 enable_stable_row_ids,
2515 ..Default::default()
2516 }),
2517 )
2518 .await
2519 .unwrap();
2520
2521 assert_eq!(2, dataset.get_fragments().len());
2522
2523 Arc::new(dataset)
2524 }
2525
2526 async fn get_row_ids_for_keys(dataset: &Dataset, keys: &[u32]) -> UInt64Array {
2527 let filter = format!(
2528 "key IN ({})",
2529 keys.iter()
2530 .map(|k| k.to_string())
2531 .collect::<Vec<_>>()
2532 .join(",")
2533 );
2534
2535 let batch = dataset
2536 .scan()
2537 .filter(&filter)
2538 .unwrap()
2539 .with_row_id()
2540 .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
2541 "key".to_string(),
2542 )]))
2543 .unwrap()
2544 .try_into_batch()
2545 .await
2546 .unwrap();
2547
2548 batch
2549 .column_by_name(ROW_ID)
2550 .unwrap()
2551 .as_any()
2552 .downcast_ref::<UInt64Array>()
2553 .unwrap()
2554 .clone()
2555 }
2556
2557 fn create_delete_condition() -> Expr {
2558 Expr::gt(
2559 Expr::Column(Column::new_unqualified("key")),
2560 Expr::Literal(ScalarValue::UInt32(Some(1)), None),
2561 )
2562 }
2563
2564 struct MergeInsertTestBuilder {
2565 version: LanceFileVersion,
2566 enable_stable_row_ids: bool,
2567 test_keys: Vec<u32>,
2568 expected_left_keys: Vec<u32>,
2569 expected_right_keys: Vec<u32>,
2570 expected_stats: Vec<u64>,
2571 job_builder: Option<Box<dyn FnOnce(Arc<Dataset>) -> MergeInsertJob>>,
2572 }
2573
2574 impl MergeInsertTestBuilder {
2575 fn new() -> Self {
2576 Self {
2577 version: LanceFileVersion::default(),
2578 enable_stable_row_ids: false,
2579 test_keys: vec![],
2580 expected_left_keys: vec![],
2581 expected_right_keys: vec![],
2582 expected_stats: vec![],
2583 job_builder: None,
2584 }
2585 }
2586
2587 fn with_version(mut self, version: LanceFileVersion) -> Self {
2588 self.version = version;
2589 self
2590 }
2591
2592 fn with_stable_row_ids(mut self, enable: bool) -> Self {
2593 self.enable_stable_row_ids = enable;
2594 self
2595 }
2596
2597 fn with_test_keys(mut self, keys: &[u32]) -> Self {
2598 self.test_keys = keys.to_vec();
2599 self
2600 }
2601
2602 fn with_expected_left_keys(mut self, keys: &[u32]) -> Self {
2603 self.expected_left_keys = keys.to_vec();
2604 self
2605 }
2606
2607 fn with_expected_right_keys(mut self, keys: &[u32]) -> Self {
2608 self.expected_right_keys = keys.to_vec();
2609 self
2610 }
2611
2612 fn with_expected_stats(mut self, stats: &[u64]) -> Self {
2613 self.expected_stats = stats.to_vec();
2614 self
2615 }
2616
2617 fn with_job_builder<F>(mut self, builder: F) -> Self
2618 where
2619 F: FnOnce(Arc<Dataset>) -> MergeInsertJob + 'static,
2620 {
2621 self.job_builder = Some(Box::new(builder));
2622 self
2623 }
2624
2625 async fn run_test(self) {
2626 let schema = create_test_schema();
2627 let new_batch = create_new_batch(schema.clone());
2628 let test_uri = "memory://test.lance";
2629
2630 let ds = create_test_dataset(test_uri, self.version, self.enable_stable_row_ids).await;
2631 let row_ids_before = get_row_ids_for_keys(&ds, &self.test_keys).await;
2632
2633 let job_builder = self.job_builder.expect("job_builder must be set");
2634 let job = job_builder(ds);
2635 let ds = check_then_refresh_dataset(
2636 new_batch,
2637 job,
2638 &self.expected_left_keys,
2639 &self.expected_right_keys,
2640 &self.expected_stats,
2641 )
2642 .await;
2643
2644 let row_ids_after = get_row_ids_for_keys(&ds, &self.test_keys).await;
2645
2646 if self.enable_stable_row_ids {
2647 assert_eq!(row_ids_before, row_ids_after);
2648 } else {
2649 assert_ne!(row_ids_before, row_ids_after);
2650 }
2651 }
2652 }
2653
2654 #[tokio::test]
2655 async fn test_merge_insert_requires_on_or_primary_key() {
2656 let test_uri = "memory://merge_insert_requires_keys";
2657
2658 let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await;
2659
2660 let err = MergeInsertBuilder::try_new(ds, Vec::new()).unwrap_err();
2661 if let crate::Error::InvalidInput { source, .. } = err {
2662 let msg = source.to_string();
2663 assert!(
2664 msg.contains("requires join keys") && msg.contains("primary key"),
2665 "unexpected error message: {}",
2666 msg
2667 );
2668 } else {
2669 panic!("expected InvalidInput error");
2670 }
2671 }
2672
2673 #[tokio::test]
2674 async fn test_merge_insert_defaults_to_unenforced_primary_key() {
2675 let id_field = Field::new("id", DataType::Int32, false).with_metadata(
2677 [(
2678 "lance-schema:unenforced-primary-key".to_string(),
2679 "true".to_string(),
2680 )]
2681 .into(),
2682 );
2683 let value_field = Field::new("value", DataType::Int32, false);
2684 let schema = Arc::new(Schema::new(vec![id_field, value_field]));
2685
2686 let initial_batch = RecordBatch::try_new(
2687 schema.clone(),
2688 vec![
2689 Arc::new(Int32Array::from(vec![1, 2, 3])),
2690 Arc::new(Int32Array::from(vec![10, 20, 30])),
2691 ],
2692 )
2693 .unwrap();
2694
2695 let reader = RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone());
2696 let dataset = Dataset::write(
2697 reader,
2698 "memory://merge_insert_pk_default",
2699 Some(WriteParams {
2700 data_storage_version: Some(LanceFileVersion::V2_0),
2701 ..Default::default()
2702 }),
2703 )
2704 .await
2705 .unwrap();
2706 let dataset = Arc::new(dataset);
2707
2708 let new_batch = RecordBatch::try_new(
2710 schema.clone(),
2711 vec![
2712 Arc::new(Int32Array::from(vec![2, 3, 4])),
2713 Arc::new(Int32Array::from(vec![200, 300, 400])),
2714 ],
2715 )
2716 .unwrap();
2717
2718 let mut builder = MergeInsertBuilder::try_new(dataset.clone(), Vec::new()).unwrap();
2719 builder
2720 .when_matched(WhenMatched::UpdateAll)
2721 .when_not_matched(WhenNotMatched::InsertAll);
2722 let job = builder.try_build().unwrap();
2723
2724 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
2725 let new_stream = reader_to_stream(new_reader);
2726
2727 let (updated_dataset, stats) = job.execute(new_stream).await.unwrap();
2728
2729 assert_eq!(stats.num_inserted_rows, 1);
2730 assert_eq!(stats.num_updated_rows, 2);
2731 assert_eq!(stats.num_deleted_rows, 0);
2732
2733 let result_batch = updated_dataset.scan().try_into_batch().await.unwrap();
2734 let ids = result_batch
2735 .column_by_name("id")
2736 .unwrap()
2737 .as_primitive::<Int32Type>();
2738 let values = result_batch
2739 .column_by_name("value")
2740 .unwrap()
2741 .as_primitive::<Int32Type>();
2742
2743 let mut pairs = (0..ids.len())
2744 .map(|i| (ids.value(i), values.value(i)))
2745 .collect::<Vec<_>>();
2746 pairs.sort_unstable();
2747
2748 assert_eq!(pairs, vec![(1, 10), (2, 200), (3, 300), (4, 400)]);
2749 }
2750
2751 #[rstest::rstest]
2752 #[tokio::test]
2753 async fn test_basic_merge(
2754 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
2755 ) {
2756 let schema = create_test_schema();
2757 let new_batch = create_new_batch(schema.clone());
2758
2759 let test_uri = "memory://test.lance";
2760
2761 let ds = create_test_dataset(test_uri, version, false).await;
2762
2763 assert!(MergeInsertBuilder::try_new(ds.clone(), vec![]).is_err());
2765
2766 let keys = vec!["key".to_string()];
2767 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2769 .unwrap()
2770 .try_build()
2771 .unwrap();
2772 check_then_refresh_dataset(
2773 new_batch.clone(),
2774 job,
2775 &[1, 2, 3, 4, 5, 6],
2776 &[7, 8, 9],
2777 &[3, 0, 0],
2778 )
2779 .await;
2780
2781 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2783 .unwrap()
2784 .when_matched(WhenMatched::UpdateAll)
2785 .try_build()
2786 .unwrap();
2787 check_then_refresh_dataset(
2788 new_batch.clone(),
2789 job,
2790 &[1, 2, 3],
2791 &[4, 5, 6, 7, 8, 9],
2792 &[3, 3, 0],
2793 )
2794 .await;
2795
2796 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2798 .unwrap()
2799 .when_matched(
2800 WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
2801 )
2802 .try_build()
2803 .unwrap();
2804 check_then_refresh_dataset(
2805 new_batch.clone(),
2806 job,
2807 &[1, 2, 3, 4, 5],
2808 &[6, 7, 8, 9],
2809 &[3, 1, 0],
2810 )
2811 .await;
2812
2813 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2815 .unwrap()
2816 .when_not_matched(WhenNotMatched::DoNothing)
2817 .when_matched(WhenMatched::update_if(&ds, "target.filterme = 'z'").unwrap())
2818 .try_build()
2819 .unwrap();
2820 check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3, 4, 5, 6], &[], &[0, 0, 0])
2821 .await;
2822
2823 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2825 .unwrap()
2826 .when_matched(WhenMatched::UpdateAll)
2827 .when_not_matched(WhenNotMatched::DoNothing)
2828 .try_build()
2829 .unwrap();
2830 check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3], &[4, 5, 6], &[0, 3, 0])
2831 .await;
2832
2833 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2835 .unwrap()
2836 .when_matched(
2837 WhenMatched::update_if(&ds, "source.filterme == target.filterme").unwrap(),
2838 )
2839 .when_not_matched(WhenNotMatched::DoNothing)
2840 .try_build()
2841 .unwrap();
2842 check_then_refresh_dataset(new_batch.clone(), job, &[1, 2, 3, 6], &[4, 5], &[0, 2, 0])
2843 .await;
2844
2845 assert!(
2847 MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2848 .unwrap()
2849 .when_not_matched(WhenNotMatched::DoNothing)
2850 .try_build()
2851 .is_err()
2852 );
2853
2854 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2856 .unwrap()
2857 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2858 .try_build()
2859 .unwrap();
2860 check_then_refresh_dataset(new_batch.clone(), job, &[4, 5, 6], &[7, 8, 9], &[3, 0, 3])
2861 .await;
2862
2863 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2865 .unwrap()
2866 .when_matched(WhenMatched::UpdateAll)
2867 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2868 .try_build()
2869 .unwrap();
2870 check_then_refresh_dataset(new_batch.clone(), job, &[], &[4, 5, 6, 7, 8, 9], &[3, 3, 3])
2871 .await;
2872
2873 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2875 .unwrap()
2876 .when_matched(
2877 WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
2878 )
2879 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2880 .try_build()
2881 .unwrap();
2882 check_then_refresh_dataset(new_batch.clone(), job, &[4, 5], &[6, 7, 8, 9], &[3, 1, 3])
2883 .await;
2884
2885 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2887 .unwrap()
2888 .when_matched(WhenMatched::UpdateAll)
2889 .when_not_matched(WhenNotMatched::DoNothing)
2890 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2891 .try_build()
2892 .unwrap();
2893 check_then_refresh_dataset(new_batch.clone(), job, &[], &[4, 5, 6], &[0, 3, 3]).await;
2894
2895 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2897 .unwrap()
2898 .when_not_matched(WhenNotMatched::DoNothing)
2899 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2900 .try_build()
2901 .unwrap();
2902 check_then_refresh_dataset(new_batch.clone(), job, &[4, 5, 6], &[], &[0, 0, 3]).await;
2903
2904 let condition = create_delete_condition();
2906 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2908 .unwrap()
2909 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
2910 .try_build()
2911 .unwrap();
2912 check_then_refresh_dataset(
2913 new_batch.clone(),
2914 job,
2915 &[1, 4, 5, 6],
2916 &[7, 8, 9],
2917 &[3, 0, 2],
2918 )
2919 .await;
2920
2921 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2923 .unwrap()
2924 .when_matched(WhenMatched::UpdateAll)
2925 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
2926 .try_build()
2927 .unwrap();
2928 check_then_refresh_dataset(
2929 new_batch.clone(),
2930 job,
2931 &[1],
2932 &[4, 5, 6, 7, 8, 9],
2933 &[3, 3, 2],
2934 )
2935 .await;
2936
2937 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2939 .unwrap()
2940 .when_matched(
2941 WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
2942 )
2943 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
2944 .try_build()
2945 .unwrap();
2946 check_then_refresh_dataset(
2947 new_batch.clone(),
2948 job,
2949 &[1, 4, 5],
2950 &[6, 7, 8, 9],
2951 &[3, 1, 2],
2952 )
2953 .await;
2954
2955 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2957 .unwrap()
2958 .when_matched(WhenMatched::UpdateAll)
2959 .when_not_matched(WhenNotMatched::DoNothing)
2960 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
2961 .try_build()
2962 .unwrap();
2963 check_then_refresh_dataset(new_batch.clone(), job, &[1], &[4, 5, 6], &[0, 3, 2]).await;
2964
2965 let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
2967 .unwrap()
2968 .when_not_matched(WhenNotMatched::DoNothing)
2969 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone()))
2970 .try_build()
2971 .unwrap();
2972 check_then_refresh_dataset(new_batch.clone(), job, &[1, 4, 5, 6], &[], &[0, 0, 2]).await;
2973 }
2974
2975 #[rstest::rstest]
2976 #[tokio::test]
2977 async fn test_upsert_and_delete_all_with_stable_row_id(
2978 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
2979 #[values(true, false)] enable_stable_row_ids: bool,
2980 ) {
2981 MergeInsertTestBuilder::new()
2982 .with_version(version)
2983 .with_stable_row_ids(enable_stable_row_ids)
2984 .with_test_keys(&[4, 5, 6])
2985 .with_expected_left_keys(&[])
2986 .with_expected_right_keys(&[4, 5, 6, 7, 8, 9])
2987 .with_expected_stats(&[3, 3, 3])
2988 .with_job_builder(|ds| {
2989 MergeInsertBuilder::try_new(ds, vec!["key".to_string()])
2990 .unwrap()
2991 .when_matched(WhenMatched::UpdateAll)
2992 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
2993 .try_build()
2994 .unwrap()
2995 })
2996 .run_test()
2997 .await;
2998 }
2999
3000 #[rstest::rstest]
3001 #[tokio::test]
3002 async fn test_upsert_only_with_stable_row_id(
3003 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3004 #[values(true, false)] enable_stable_row_ids: bool,
3005 ) {
3006 MergeInsertTestBuilder::new()
3007 .with_version(version)
3008 .with_stable_row_ids(enable_stable_row_ids)
3009 .with_test_keys(&[4, 5, 6])
3010 .with_expected_left_keys(&[1, 2, 3])
3011 .with_expected_right_keys(&[4, 5, 6, 7, 8, 9])
3012 .with_expected_stats(&[3, 3, 0])
3013 .with_job_builder(|ds| {
3014 MergeInsertBuilder::try_new(ds, vec!["key".to_string()])
3015 .unwrap()
3016 .when_matched(WhenMatched::UpdateAll)
3017 .try_build()
3018 .unwrap()
3019 })
3020 .run_test()
3021 .await;
3022 }
3023
3024 #[rstest::rstest]
3025 #[tokio::test]
3026 async fn test_conditional_update_with_stable_row_id(
3027 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3028 #[values(true, false)] enable_stable_row_ids: bool,
3029 ) {
3030 MergeInsertTestBuilder::new()
3031 .with_version(version)
3032 .with_stable_row_ids(enable_stable_row_ids)
3033 .with_test_keys(&[6])
3034 .with_expected_left_keys(&[1, 2, 3, 4, 5])
3035 .with_expected_right_keys(&[6, 7, 8, 9])
3036 .with_expected_stats(&[3, 1, 0])
3037 .with_job_builder(|ds| {
3038 let keys = vec!["key".to_string()];
3039 MergeInsertBuilder::try_new(ds.clone(), keys)
3040 .unwrap()
3041 .when_matched(
3042 WhenMatched::update_if(&ds, "source.filterme != target.filterme").unwrap(),
3043 )
3044 .try_build()
3045 .unwrap()
3046 })
3047 .run_test()
3048 .await;
3049 }
3050
3051 #[rstest::rstest]
3052 #[tokio::test]
3053 async fn test_update_only_with_stable_row_id(
3054 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3055 #[values(true, false)] enable_stable_row_ids: bool,
3056 ) {
3057 MergeInsertTestBuilder::new()
3058 .with_version(version)
3059 .with_stable_row_ids(enable_stable_row_ids)
3060 .with_test_keys(&[4, 5, 6])
3061 .with_expected_left_keys(&[1, 2, 3])
3062 .with_expected_right_keys(&[4, 5, 6])
3063 .with_expected_stats(&[0, 3, 0])
3064 .with_job_builder(|ds| {
3065 let keys = vec!["key".to_string()];
3066 MergeInsertBuilder::try_new(ds, keys)
3067 .unwrap()
3068 .when_matched(WhenMatched::UpdateAll)
3069 .when_not_matched(WhenNotMatched::DoNothing)
3070 .try_build()
3071 .unwrap()
3072 })
3073 .run_test()
3074 .await;
3075 }
3076
3077 #[rstest::rstest]
3078 #[tokio::test]
3079 async fn test_upsert_with_conditional_delete_and_stable_row_id(
3080 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3081 #[values(true, false)] enable_stable_row_ids: bool,
3082 ) {
3083 MergeInsertTestBuilder::new()
3084 .with_version(version)
3085 .with_stable_row_ids(enable_stable_row_ids)
3086 .with_test_keys(&[1, 4, 5, 6])
3087 .with_expected_left_keys(&[1])
3088 .with_expected_right_keys(&[4, 5, 6, 7, 8, 9])
3089 .with_expected_stats(&[3, 3, 2])
3090 .with_job_builder(|ds| {
3091 let keys = vec!["key".to_string()];
3092 let condition = create_delete_condition();
3093 MergeInsertBuilder::try_new(ds, keys)
3094 .unwrap()
3095 .when_matched(WhenMatched::UpdateAll)
3096 .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition))
3097 .try_build()
3098 .unwrap()
3099 })
3100 .run_test()
3101 .await;
3102 }
3103
3104 #[rstest::rstest]
3105 #[tokio::test]
3106 async fn test_multiple_merge_insert_stable_row_id(
3107 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3108 #[values(true, false)] enable_stable_row_ids: bool,
3109 ) {
3110 let schema = create_test_schema();
3111 let test_uri = "memory://test_multiple_merge.lance";
3112
3113 let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
3114
3115 let target_key = 2u32;
3116 let target_keys = vec![target_key];
3117
3118 let initial_row_ids = get_row_ids_for_keys(&ds, &target_keys).await;
3119 let initial_row_id = initial_row_ids.value(0);
3120
3121 let mut current_ds = ds;
3122
3123 for iteration in 1..=3 {
3124 let new_value = 1000u32 + iteration * 10;
3125 let new_batch = RecordBatch::try_new(
3126 schema.clone(),
3127 vec![
3128 Arc::new(UInt32Array::from(vec![target_key])), Arc::new(UInt32Array::from(vec![new_value])), Arc::new(StringArray::from(vec![format!("iteration_{}", iteration)])), ],
3132 )
3133 .unwrap();
3134
3135 let job = MergeInsertBuilder::try_new(current_ds.clone(), vec!["key".to_string()])
3136 .unwrap()
3137 .when_matched(WhenMatched::UpdateAll)
3138 .when_not_matched(WhenNotMatched::DoNothing)
3139 .try_build()
3140 .unwrap();
3141
3142 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
3143 let new_stream = reader_to_stream(new_reader);
3144 let (updated_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
3145
3146 assert_eq!(
3147 merge_stats.num_updated_rows, 1,
3148 "Iteration {}: Expected 1 updated row",
3149 iteration
3150 );
3151 assert_eq!(
3152 merge_stats.num_inserted_rows, 0,
3153 "Iteration {}: Expected 0 inserted rows",
3154 iteration
3155 );
3156 assert_eq!(
3157 merge_stats.num_deleted_rows, 0,
3158 "Iteration {}: Expected 0 deleted rows",
3159 iteration
3160 );
3161
3162 let updated_row_ids = get_row_ids_for_keys(&updated_dataset, &target_keys).await;
3163 let updated_row_id = updated_row_ids.value(0);
3164
3165 let updated_batch = updated_dataset
3166 .scan()
3167 .filter(&format!("key = {}", target_key))
3168 .unwrap()
3169 .try_into_batch()
3170 .await
3171 .unwrap();
3172
3173 let value_col = updated_batch
3174 .column_by_name("value")
3175 .unwrap()
3176 .as_any()
3177 .downcast_ref::<UInt32Array>()
3178 .unwrap();
3179 let filterme_col = updated_batch
3180 .column_by_name("filterme")
3181 .unwrap()
3182 .as_any()
3183 .downcast_ref::<StringArray>()
3184 .unwrap();
3185
3186 assert_eq!(
3187 value_col.value(0),
3188 new_value,
3189 "Iteration {}: Value should be updated to {}",
3190 iteration,
3191 new_value
3192 );
3193 assert_eq!(filterme_col.value(0), format!("iteration_{}", iteration));
3194
3195 if enable_stable_row_ids {
3196 assert_eq!(
3197 updated_row_id, initial_row_id,
3198 "Iteration {}: Row ID should remain stable across merge inserts when stable_row_ids is enabled. Initial: {}, Current: {}",
3199 iteration, initial_row_id, updated_row_id
3200 );
3201 }
3202
3203 current_ds = updated_dataset;
3204 }
3205
3206 let final_batch = current_ds
3207 .scan()
3208 .filter(&format!("key = {}", target_key))
3209 .unwrap()
3210 .try_into_batch()
3211 .await
3212 .unwrap();
3213
3214 assert_eq!(
3215 final_batch.num_rows(),
3216 1,
3217 "Should have exactly one row for the target key"
3218 );
3219
3220 let final_value = final_batch
3221 .column_by_name("value")
3222 .unwrap()
3223 .as_any()
3224 .downcast_ref::<UInt32Array>()
3225 .unwrap()
3226 .value(0);
3227 let final_filterme = final_batch
3228 .column_by_name("filterme")
3229 .unwrap()
3230 .as_any()
3231 .downcast_ref::<StringArray>()
3232 .unwrap()
3233 .value(0);
3234
3235 assert_eq!(
3236 final_value, 1030u32,
3237 "Final value should be from last iteration"
3238 );
3239 assert_eq!(
3240 final_filterme, "iteration_3",
3241 "Final filterme should be from last iteration"
3242 );
3243 }
3244
3245 #[rstest::rstest]
3246 #[tokio::test]
3247 async fn test_row_id_stability_across_update_and_merge_insert(
3248 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
3249 #[values(true, false)] enable_stable_row_ids: bool,
3250 ) {
3251 let schema = create_test_schema();
3252 let test_uri = "memory://test_row_id_stability.lance";
3253
3254 let mut dataset = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
3255
3256 let target_key = 2u32;
3257 let target_keys = vec![target_key];
3258
3259 let initial_row_ids = get_row_ids_for_keys(&dataset, &target_keys).await;
3260 let initial_row_id = initial_row_ids.value(0);
3261
3262 let initial_batch = dataset
3263 .scan()
3264 .filter(&format!("key = {}", target_key))
3265 .unwrap()
3266 .with_row_id()
3267 .try_into_batch()
3268 .await
3269 .unwrap();
3270
3271 let initial_value = initial_batch
3272 .column_by_name("value")
3273 .unwrap()
3274 .as_primitive::<UInt32Type>()
3275 .value(0);
3276
3277 let update_result = crate::dataset::UpdateBuilder::new(Arc::new((*dataset).clone()))
3278 .update_where(&format!("key = {}", target_key))
3279 .unwrap()
3280 .set("value", "value + 100")
3281 .unwrap()
3282 .build()
3283 .unwrap()
3284 .execute()
3285 .await
3286 .unwrap();
3287
3288 dataset = update_result.new_dataset.clone();
3289
3290 let after_update_row_ids = get_row_ids_for_keys(&dataset, &target_keys).await;
3291 let after_update_row_id = after_update_row_ids.value(0);
3292
3293 let after_update_batch = dataset
3294 .scan()
3295 .filter(&format!("key = {}", target_key))
3296 .unwrap()
3297 .with_row_id()
3298 .try_into_batch()
3299 .await
3300 .unwrap();
3301
3302 let after_update_value = after_update_batch
3303 .column_by_name("value")
3304 .unwrap()
3305 .as_primitive::<UInt32Type>()
3306 .value(0);
3307
3308 if enable_stable_row_ids {
3309 assert_eq!(
3310 initial_row_id, after_update_row_id,
3311 "Row ID should remain stable after update"
3312 );
3313 } else {
3314 assert_ne!(
3315 initial_row_id, after_update_row_id,
3316 "Row ID should change after update when stable row IDs are disabled"
3317 );
3318 }
3319 assert_eq!(
3320 after_update_value,
3321 initial_value + 100,
3322 "Value should be updated correctly"
3323 );
3324
3325 let merge_new_value = 500u32;
3326 let new_batch = RecordBatch::try_new(
3327 schema.clone(),
3328 vec![
3329 Arc::new(UInt32Array::from(vec![target_key])),
3330 Arc::new(UInt32Array::from(vec![merge_new_value])),
3331 Arc::new(StringArray::from(vec!["UPDATED"])),
3332 ],
3333 )
3334 .unwrap();
3335
3336 let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
3337 .unwrap()
3338 .when_matched(WhenMatched::UpdateAll)
3339 .try_build()
3340 .unwrap();
3341
3342 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
3343 let new_stream = reader_to_stream(new_reader);
3344
3345 let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
3346
3347 let after_merge_row_ids = get_row_ids_for_keys(&merged_dataset, &target_keys).await;
3348 let after_merge_row_id = after_merge_row_ids.value(0);
3349
3350 let after_merge_batch = merged_dataset
3351 .scan()
3352 .filter(&format!("key = {}", target_key))
3353 .unwrap()
3354 .with_row_id()
3355 .try_into_batch()
3356 .await
3357 .unwrap();
3358
3359 let after_merge_value = after_merge_batch
3360 .column_by_name("value")
3361 .unwrap()
3362 .as_primitive::<UInt32Type>()
3363 .value(0);
3364
3365 let after_merge_filterme = after_merge_batch
3366 .column_by_name("filterme")
3367 .unwrap()
3368 .as_any()
3369 .downcast_ref::<StringArray>()
3370 .unwrap()
3371 .value(0);
3372
3373 if enable_stable_row_ids {
3374 assert_eq!(
3375 initial_row_id, after_merge_row_id,
3376 "Row ID should remain stable after merge insert"
3377 );
3378 assert_eq!(
3379 after_update_row_id, after_merge_row_id,
3380 "Row ID should remain the same across update and merge insert"
3381 );
3382 } else {
3383 assert_ne!(
3384 after_update_row_id, after_merge_row_id,
3385 "Row ID should change after merge insert when stable row IDs are disabled"
3386 );
3387 }
3388
3389 assert_eq!(
3390 after_merge_value, merge_new_value,
3391 "Value should be updated by merge insert"
3392 );
3393 assert_eq!(
3394 after_merge_filterme, "UPDATED",
3395 "Filterme should be updated by merge insert"
3396 );
3397
3398 assert_eq!(
3399 merge_stats.num_updated_rows, 1,
3400 "Should update exactly 1 row"
3401 );
3402 assert_eq!(
3403 merge_stats.num_inserted_rows, 0,
3404 "Should not insert any new rows"
3405 );
3406 assert_eq!(
3407 merge_stats.num_deleted_rows, 0,
3408 "Should not delete any rows"
3409 );
3410
3411 if enable_stable_row_ids {
3412 assert_eq!(
3413 initial_row_id, after_merge_row_id,
3414 "Row ID should remain stable throughout the entire process of update and merge insert"
3415 );
3416 }
3417 }
3418
3419 #[tokio::test]
3420 async fn test_indexed_merge_insert() {
3421 let test_dir = TempStrDir::default();
3422 let test_uri = &test_dir;
3423
3424 let data = lance_datagen::gen_batch()
3425 .with_seed(Seed::from(1))
3426 .col("value", array::step::<UInt32Type>())
3427 .col("key", array::rand_pseudo_uuid_hex());
3428 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
3429 let schema = data.schema();
3430
3431 let mut ds = Dataset::write(data, test_uri, None).await.unwrap();
3433 let index_params = ScalarIndexParams::default();
3434 ds.create_index(&["key"], IndexType::Scalar, None, &index_params, false)
3435 .await
3436 .unwrap();
3437
3438 let data = lance_datagen::gen_batch()
3440 .with_seed(Seed::from(2))
3441 .col("value", array::step::<UInt32Type>())
3442 .col("key", array::rand_pseudo_uuid_hex());
3443 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(8));
3444 let ds = Dataset::write(
3445 data,
3446 test_uri,
3447 Some(WriteParams {
3448 mode: WriteMode::Append,
3449 ..Default::default()
3450 }),
3451 )
3452 .await
3453 .unwrap();
3454
3455 let ds = Arc::new(ds);
3456
3457 let just_index_col = Schema::new(vec![Field::new("key", DataType::Utf8, false)]);
3458
3459 let some_indices = ds
3461 .sample(2048, &(&just_index_col).try_into().unwrap(), None)
3462 .await
3463 .unwrap();
3464 let some_indices = some_indices.column(0).clone();
3465 let some_vals = lance_datagen::gen_batch()
3466 .anon_col(array::fill::<UInt32Type>(9999999))
3467 .into_batch_rows(RowCount::from(2048))
3468 .unwrap();
3469 let some_vals = some_vals.column(0).clone();
3470 let source_batch =
3471 RecordBatch::try_new(schema.clone(), vec![some_vals, some_indices]).unwrap();
3472 let source_batches = vec![
3474 source_batch.slice(0, 512),
3475 source_batch.slice(512, 512),
3476 source_batch.slice(1024, 512),
3477 source_batch.slice(1536, 512),
3478 ];
3479 let source = Box::new(RecordBatchIterator::new(
3480 source_batches.clone().into_iter().map(Ok),
3481 schema.clone(),
3482 ));
3483
3484 let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3486 .unwrap()
3487 .when_not_matched(WhenNotMatched::DoNothing)
3488 .when_matched(WhenMatched::UpdateAll)
3489 .try_build()
3490 .unwrap()
3491 .execute_reader(source)
3492 .await
3493 .unwrap();
3494
3495 let updated = ds
3497 .count_rows(Some("value = 9999999".to_string()))
3498 .await
3499 .unwrap();
3500 assert_eq!(updated, 2048);
3501
3502 let source = Box::new(RecordBatchIterator::new(
3504 source_batches.clone().into_iter().map(Ok),
3505 schema.clone(),
3506 ));
3507 let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3509 .unwrap()
3510 .when_not_matched(WhenNotMatched::DoNothing)
3511 .when_matched(WhenMatched::UpdateAll)
3512 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
3513 .try_build()
3514 .unwrap()
3515 .execute_reader(source)
3516 .await
3517 .unwrap();
3518
3519 assert_eq!(ds.count_rows(None).await.unwrap(), 2048);
3521
3522 let source = Box::new(RecordBatchIterator::new(
3523 source_batches.clone().into_iter().map(Ok),
3524 schema.clone(),
3525 ));
3526 let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3529 .unwrap()
3530 .when_not_matched(WhenNotMatched::DoNothing)
3531 .when_matched(WhenMatched::UpdateAll)
3532 .try_build()
3533 .unwrap()
3534 .execute_reader(source)
3535 .await
3536 .unwrap();
3537
3538 assert_eq!(ds.count_rows(None).await.unwrap(), 2048);
3539 }
3540
3541 mod subcols {
3542 use super::*;
3543 use rstest::rstest;
3544
3545 struct Fixtures {
3546 ds: Arc<Dataset>,
3547 new_data: RecordBatch,
3548 }
3549
3550 async fn setup(scalar_index: bool) -> Fixtures {
3551 let data = lance_datagen::gen_batch()
3552 .with_seed(Seed::from(1))
3553 .col("other", array::rand_utf8(4.into(), false))
3554 .col("value", array::step::<UInt32Type>())
3555 .col("key", array::rand_pseudo_uuid_hex());
3556 let batch = data.into_batch_rows(RowCount::from(1024 + 2)).unwrap();
3557 let batch1 = batch.slice(0, 512);
3558 let batch2 = batch.slice(512, 512);
3559 let batch3 = batch.slice(1024, 2);
3560 let schema = batch.schema();
3561
3562 let reader = Box::new(RecordBatchIterator::new(
3563 [Ok(batch1.clone())],
3564 schema.clone(),
3565 ));
3566 let write_params = WriteParams {
3567 max_rows_per_file: 256,
3568 max_rows_per_group: 32, ..Default::default()
3570 };
3571 let mut ds = Dataset::write(reader, "memory://", Some(write_params.clone()))
3572 .await
3573 .unwrap();
3574
3575 if scalar_index {
3576 let index_params = ScalarIndexParams::default();
3577 ds.create_index(&["key"], IndexType::Scalar, None, &index_params, false)
3578 .await
3579 .unwrap();
3580 }
3581
3582 let reader = Box::new(RecordBatchIterator::new(
3584 [Ok(batch2.clone())],
3585 batch2.schema(),
3586 ));
3587 ds.append(reader, Some(write_params)).await.unwrap();
3588
3589 let ds = Arc::new(ds);
3590
3591 let update_schema = Arc::new(schema.project(&[2, 1]).unwrap());
3593 let indices: Int64Array = (256..512).chain(600..612).chain([712, 715]).collect();
3595 let keys = arrow::compute::take(batch["key"].as_ref(), &indices, None).unwrap();
3596 let keys = arrow::compute::concat(&[&keys, &batch3["key"]]).unwrap();
3597 let num_rows = keys.len();
3598 let new_data = RecordBatch::try_new(
3599 update_schema,
3600 vec![
3601 keys,
3602 Arc::new((1024..(1024 + num_rows as u32)).collect::<UInt32Array>()),
3603 ],
3604 )
3605 .unwrap();
3606
3607 Fixtures { ds, new_data }
3608 }
3609
3610 #[tokio::test]
3611 async fn test_delete_not_matched_by_source_on_v2_subcols() {
3612 let Fixtures { ds, new_data } = Box::pin(setup(false)).await;
3619
3620 let rows_before = ds.count_rows(None).await.unwrap() as u64;
3621
3622 let reader = Box::new(RecordBatchIterator::new(
3623 [Ok(new_data.clone())],
3624 new_data.schema(),
3625 ));
3626
3627 let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3628 .unwrap()
3629 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
3630 .when_matched(WhenMatched::UpdateAll)
3631 .when_not_matched(WhenNotMatched::DoNothing)
3632 .try_build()
3633 .unwrap();
3634 let (updated_ds, stats) = assert_send(job.execute_reader(reader))
3638 .await
3639 .expect("partial-schema + delete-by-source should succeed on v2");
3640
3641 assert_eq!(stats.num_updated_rows, 270);
3646 assert_eq!(stats.num_inserted_rows, 0);
3647 assert_eq!(stats.num_deleted_rows, rows_before - 270);
3648 assert_eq!(
3649 updated_ds.count_rows(None).await.unwrap() as u64,
3650 270,
3651 "only the 270 updated rows should remain after the delete-by-source"
3652 );
3653 }
3654
3655 #[tokio::test]
3656 async fn test_errors_on_bad_schema() {
3657 let Fixtures { ds, new_data } = Box::pin(setup(false)).await;
3658
3659 let bad_schema = Arc::new(Schema::new(vec![
3661 Field::new("wrong_key", DataType::Utf8, false),
3662 Field::new("wrong_value", DataType::UInt32, false),
3663 ]));
3664
3665 let bad_batch =
3667 RecordBatch::try_new(bad_schema.clone(), new_data.columns().to_vec()).unwrap();
3668 let reader = Box::new(RecordBatchIterator::new([Ok(bad_batch)], bad_schema));
3669
3670 let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3671 .unwrap()
3672 .when_matched(WhenMatched::UpdateAll)
3673 .when_not_matched(WhenNotMatched::DoNothing)
3674 .try_build()
3675 .unwrap();
3676 let res = job.execute_reader(reader).await;
3677 assert!(
3678 matches!(
3679 &res,
3680 &Err(Error::SchemaMismatch { ref difference, .. })
3681 if difference.clone().contains("fields did not match")
3682 ),
3683 "Expected SchemaMismatch error, got: {:?}",
3684 res
3685 );
3686 }
3687
/// Partial-schema upsert (the source has only `key` and `value`), run with and
/// without a scalar index on `key`, and with `when_not_matched` either inserting
/// or ignoring the unmatched source rows.
///
/// All source rows except the last two match existing target rows. The test
/// checks the merge stats, then the resulting fragment layout (which differs
/// between the indexed and non-indexed paths), then the row contents.
#[rstest]
#[tokio::test]
async fn test_merge_insert_subcols(
    #[values(false, true)] scalar_index: bool,
    #[values(false, true)] insert: bool,
) {
    let Fixtures { ds, new_data } = Box::pin(setup(scalar_index)).await;
    let reader = Box::new(RecordBatchIterator::new(
        [Ok(new_data.clone())],
        new_data.schema(),
    ));
    // Snapshot fragment metadata so it can be compared after the merge.
    let fragments_before = ds
        .get_fragments()
        .iter()
        .map(|f| f.metadata().clone())
        .collect::<Vec<_>>();
    let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
        .unwrap()
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(if insert {
            WhenNotMatched::InsertAll
        } else {
            WhenNotMatched::DoNothing
        })
        .try_build()
        .unwrap();

    let (ds, stats) = job.execute_reader(reader).await.unwrap();

    let fragments_after = ds
        .get_fragments()
        .iter()
        .map(|f| f.metadata().clone())
        .collect::<Vec<_>>();

    // Every source row except the trailing two matches a target row.
    assert_eq!(stats.num_updated_rows, (new_data.num_rows() - 2) as u64);
    assert_eq!(stats.num_deleted_rows, 0);
    // The two unmatched source rows are inserted only when `insert` is set.
    if insert {
        assert_eq!(stats.num_inserted_rows, 2);
    } else {
        assert_eq!(stats.num_inserted_rows, 0);
    }

    if scalar_index {
        // Indexed path: the pre-existing fragments keep their ids and order.
        assert_eq!(
            fragments_before.iter().map(|f| f.id).collect::<Vec<_>>(),
            fragments_after
                .iter()
                .take(fragments_before.len())
                .map(|f| f.id)
                .collect::<Vec<_>>()
        );
        // Only fragments 1 and 2 hold matched keys, so only their metadata changes.
        assert_eq!(fragments_before[0], fragments_after[0]);
        assert_ne!(fragments_before[1], fragments_after[1]);
        assert_ne!(fragments_before[2], fragments_after[2]);
        assert_eq!(fragments_before[3], fragments_after[3]);

        // A touched fragment gains a second data file holding field ids [2, 1]
        // (the order of the source's projected columns). In the original file
        // those slots become -2. NOTE(review): -2 presumably marks a field that
        // has been superseded by a newer data file; confirm.
        let has_added_files = |frag: &Fragment| {
            assert_eq!(frag.files.len(), 2);
            let data_files = &frag.files;
            assert_eq!(data_files[0].fields.as_ref(), &[0, -2, -2]);
            assert_eq!(data_files[1].fields.as_ref(), &[2, 1]);
        };
        has_added_files(&fragments_after[1]);
        has_added_files(&fragments_after[2]);

        // Inserted rows, when present, add exactly one new fragment.
        if insert {
            assert_eq!(fragments_after.len(), 5);
        } else {
            assert_eq!(fragments_after.len(), 4);
        }
    } else {
        // Non-indexed (v2) path: fragment 1 disappears, fragment 2 is kept, and
        // one new fragment is appended. The count is 4 whether or not rows were
        // inserted.
        let ids_after: Vec<u64> = fragments_after.iter().map(|f| f.id).collect();
        assert_eq!(
            fragments_after.len(),
            4,
            "expected [frag 0, frag 2, frag 3, new frag], got {:?}",
            ids_after
        );
        assert_eq!(
            fragments_before[0], fragments_after[0],
            "frag 0 (untouched) should be identical"
        );
        assert!(
            !ids_after.contains(&1),
            "frag 1 was fully matched by source and should have been removed"
        );
        assert!(
            ids_after.contains(&2),
            "frag 2 was only partially matched and should still be present"
        );
        assert!(
            ids_after.contains(&3),
            "frag 3 (untouched) should still be present"
        );
    }

    // Content checks: read everything back (in order) and index rows by key.
    let data = ds
        .scan()
        .scan_in_order(true)
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(data.num_rows(), if insert { 1024 + 2 } else { 1024 });
    assert_eq!(data.num_columns(), 3);

    use std::collections::HashMap;
    let other_col = data
        .column_by_name("other")
        .unwrap()
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let value_col = data
        .column_by_name("value")
        .unwrap()
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap();
    let key_col = data
        .column_by_name("key")
        .unwrap()
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let mut row_by_key: HashMap<String, (u32, String)> = HashMap::new();
    for i in 0..data.num_rows() {
        row_by_key.insert(
            key_col.value(i).to_string(),
            (value_col.value(i), other_col.value(i).to_string()),
        );
    }

    // The source is laid out as (key, value); the positional reads below rely on it.
    let orig_batch_schema = new_data.schema();
    assert_eq!(orig_batch_schema.field(0).name(), "key");
    assert_eq!(orig_batch_schema.field(1).name(), "value");
    let new_keys = new_data
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let new_values = new_data
        .column(1)
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap();
    // Matched rows: take the source `value` and keep a non-empty `other` from
    // the target side.
    for i in 0..(new_data.num_rows() - 2) {
        let key = new_keys.value(i).to_string();
        let (value, other) = row_by_key
            .get(&key)
            .unwrap_or_else(|| panic!("updated key {} missing from result", key));
        assert_eq!(*value, new_values.value(i));
        assert!(
            !other.is_empty(),
            "updated row for key {} should retain its original `other` value",
            key
        );
    }
    // Unmatched (trailing) source rows: present with the source value only when
    // inserting, absent otherwise.
    for i in (new_data.num_rows() - 2)..new_data.num_rows() {
        let key = new_keys.value(i).to_string();
        let found = row_by_key.get(&key);
        if insert {
            let (value, _) =
                found.unwrap_or_else(|| panic!("inserted key {} missing from result", key));
            assert_eq!(*value, new_values.value(i));
        } else {
            assert!(
                found.is_none(),
                "unmatched source row for key {} must not be present when insert=false",
                key
            );
        }
    }
}
3889
3890 #[tokio::test]
3896 async fn test_merge_insert_subcols_v2_explain_plan() {
3897 let Fixtures { ds, new_data } = Box::pin(setup(false)).await;
3898
3899 let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()])
3900 .unwrap()
3901 .when_matched(WhenMatched::UpdateAll)
3902 .when_not_matched(WhenNotMatched::DoNothing)
3903 .try_build()
3904 .unwrap();
3905
3906 let source_schema: Schema = new_data.schema().as_ref().clone();
3907 let plan = job
3908 .explain_plan(Some(&source_schema), false)
3909 .await
3910 .expect("explain_plan must succeed for partial-schema upsert on v2");
3911
3912 assert!(
3918 plan.contains("MergeInsert: on=[key]"),
3919 "expected MergeInsert extension node in plan (v2 marker), got: {}",
3920 plan
3921 );
3922 assert!(
3923 plan.contains("HashJoinExec"),
3924 "expected HashJoinExec in plan, got: {}",
3925 plan
3926 );
3927 assert!(
3932 plan.contains("LanceRead") && plan.contains("projection=[other"),
3933 "target-side scan should include the filled `other` column: {}",
3934 plan
3935 );
3936 assert!(
3937 plan.contains("other@0 as other"),
3938 "expected post-join projection to carry `other` from the target side: {}",
3939 plan
3940 );
3941 }
3942
3943 #[tokio::test]
3948 async fn test_merge_insert_subcols_v2_rejects_non_nullable_insert() {
3949 let full_schema = Arc::new(Schema::new(vec![
3952 Field::new("key", DataType::Utf8, false),
3953 Field::new("value", DataType::UInt32, true),
3954 Field::new("other", DataType::Utf8, false),
3955 ]));
3956 let full_batch = RecordBatch::try_new(
3957 full_schema.clone(),
3958 vec![
3959 Arc::new(StringArray::from(vec!["k0", "k1", "k2"])),
3960 Arc::new(UInt32Array::from(vec![0, 1, 2])),
3961 Arc::new(StringArray::from(vec!["a", "b", "c"])),
3962 ],
3963 )
3964 .unwrap();
3965 let ds = Dataset::write(
3966 Box::new(RecordBatchIterator::new([Ok(full_batch)], full_schema)),
3967 "memory://",
3968 None,
3969 )
3970 .await
3971 .unwrap();
3972 let ds = Arc::new(ds);
3973
3974 let partial_schema = Arc::new(Schema::new(vec![
3976 Field::new("key", DataType::Utf8, false),
3977 Field::new("value", DataType::UInt32, true),
3978 ]));
3979 let partial_batch = RecordBatch::try_new(
3980 partial_schema.clone(),
3981 vec![
3982 Arc::new(StringArray::from(vec!["k1", "k_new"])),
3983 Arc::new(UInt32Array::from(vec![11, 99])),
3984 ],
3985 )
3986 .unwrap();
3987 let reader = Box::new(RecordBatchIterator::new(
3988 [Ok(partial_batch)],
3989 partial_schema,
3990 ));
3991
3992 let res = MergeInsertBuilder::try_new(ds, vec!["key".to_string()])
3993 .unwrap()
3994 .when_matched(WhenMatched::UpdateAll)
3995 .when_not_matched(WhenNotMatched::InsertAll)
3996 .try_build()
3997 .unwrap()
3998 .execute_reader(reader)
3999 .await;
4000
4001 match res {
4002 Err(Error::InvalidInput { source, .. }) => {
4003 let msg = source.to_string();
4004 assert!(
4005 msg.contains("partial-schema")
4006 && msg.contains("non-nullable")
4007 && msg.contains("\"other\""),
4008 "expected descriptive partial-schema / non-nullable error naming \
4009 the `other` column, got: {}",
4010 msg
4011 );
4012 }
4013 other => panic!(
4014 "expected InvalidInput error for non-nullable missing column on insert path, got: {:?}",
4015 other
4016 ),
4017 }
4018 }
4019
4020 #[tokio::test]
4028 async fn test_merge_insert_subcols_v2_camel_case_column() {
4029 let full_schema = Arc::new(Schema::new(vec![
4032 Field::new("userId", DataType::Utf8, false),
4033 Field::new("score", DataType::UInt32, true),
4034 Field::new("extraData", DataType::Utf8, true),
4035 ]));
4036 let full_batch = RecordBatch::try_new(
4037 full_schema.clone(),
4038 vec![
4039 Arc::new(StringArray::from(vec!["u1", "u2", "u3"])),
4040 Arc::new(UInt32Array::from(vec![10, 20, 30])),
4041 Arc::new(StringArray::from(vec!["a", "b", "c"])),
4042 ],
4043 )
4044 .unwrap();
4045 let ds = Dataset::write(
4046 Box::new(RecordBatchIterator::new([Ok(full_batch)], full_schema)),
4047 "memory://",
4048 None,
4049 )
4050 .await
4051 .unwrap();
4052 let ds = Arc::new(ds);
4053
4054 let partial_schema = Arc::new(Schema::new(vec![
4056 Field::new("userId", DataType::Utf8, false),
4057 Field::new("score", DataType::UInt32, true),
4058 ]));
4059 let partial_batch = RecordBatch::try_new(
4060 partial_schema.clone(),
4061 vec![
4062 Arc::new(StringArray::from(vec!["u2", "u_new"])),
4063 Arc::new(UInt32Array::from(vec![22, 99])),
4064 ],
4065 )
4066 .unwrap();
4067 let reader = Box::new(RecordBatchIterator::new(
4068 [Ok(partial_batch)],
4069 partial_schema,
4070 ));
4071
4072 let job = MergeInsertBuilder::try_new(ds.clone(), vec!["userId".to_string()])
4073 .unwrap()
4074 .when_matched(WhenMatched::UpdateAll)
4075 .when_not_matched(WhenNotMatched::InsertAll)
4076 .try_build()
4077 .unwrap();
4078 let (updated_ds, stats) = job
4079 .execute_reader(reader)
4080 .await
4081 .expect("camelCase partial-schema upsert must succeed on v2");
4082
4083 assert_eq!(stats.num_updated_rows, 1);
4084 assert_eq!(stats.num_inserted_rows, 1);
4085 assert_eq!(stats.num_deleted_rows, 0);
4086
4087 let data = updated_ds
4090 .scan()
4091 .scan_in_order(true)
4092 .try_into_batch()
4093 .await
4094 .unwrap();
4095 assert_eq!(data.num_rows(), 4);
4096 assert_eq!(data.num_columns(), 3);
4097
4098 let user_ids = data
4099 .column_by_name("userId")
4100 .expect("camelCase join key column must be present in result")
4101 .as_any()
4102 .downcast_ref::<StringArray>()
4103 .unwrap();
4104 let scores = data
4105 .column_by_name("score")
4106 .unwrap()
4107 .as_any()
4108 .downcast_ref::<UInt32Array>()
4109 .unwrap();
4110 let extra = data
4111 .column_by_name("extraData")
4112 .expect("camelCase omitted column must be present in result")
4113 .as_any()
4114 .downcast_ref::<StringArray>()
4115 .unwrap();
4116
4117 let mut by_user: std::collections::HashMap<String, (u32, Option<String>)> =
4118 std::collections::HashMap::new();
4119 for i in 0..data.num_rows() {
4120 let extra_val = if extra.is_null(i) {
4121 None
4122 } else {
4123 Some(extra.value(i).to_string())
4124 };
4125 by_user.insert(user_ids.value(i).to_string(), (scores.value(i), extra_val));
4126 }
4127
4128 assert_eq!(by_user["u1"], (10, Some("a".to_string())));
4130 assert_eq!(by_user["u3"], (30, Some("c".to_string())));
4131 assert_eq!(
4133 by_user["u2"],
4134 (22, Some("b".to_string())),
4135 "partial-schema update must preserve camelCase `extraData` from the target side of the join"
4136 );
4137 assert_eq!(
4139 by_user["u_new"],
4140 (99, None),
4141 "partial-schema insert must produce NULL for omitted camelCase column"
4142 );
4143 }
4144
4145 #[tokio::test]
4152 async fn test_merge_insert_subcols_v2_bloom_filter() {
4153 let schema = Arc::new(Schema::new(vec![
4154 Field::new("id", DataType::UInt32, false).with_metadata(
4155 vec![(
4156 "lance-schema:unenforced-primary-key".to_string(),
4157 "true".to_string(),
4158 )]
4159 .into_iter()
4160 .collect(),
4161 ),
4162 Field::new("value", DataType::UInt32, true),
4163 Field::new("tag", DataType::Utf8, true),
4164 ]));
4165 let initial = RecordBatch::try_new(
4166 schema.clone(),
4167 vec![
4168 Arc::new(UInt32Array::from(vec![0, 1, 2])),
4169 Arc::new(UInt32Array::from(vec![0, 0, 0])),
4170 Arc::new(StringArray::from(vec!["a", "b", "c"])),
4171 ],
4172 )
4173 .unwrap();
4174 let dataset = InsertBuilder::new("memory://")
4175 .execute(vec![initial])
4176 .await
4177 .unwrap();
4178 let dataset = Arc::new(dataset);
4179
4180 let partial_schema = Arc::new(Schema::new(vec![
4182 Field::new("id", DataType::UInt32, false).with_metadata(
4183 vec![(
4184 "lance-schema:unenforced-primary-key".to_string(),
4185 "true".to_string(),
4186 )]
4187 .into_iter()
4188 .collect(),
4189 ),
4190 Field::new("value", DataType::UInt32, true),
4191 ]));
4192 let partial = RecordBatch::try_new(
4193 partial_schema.clone(),
4194 vec![
4195 Arc::new(UInt32Array::from(vec![1, 5])), Arc::new(UInt32Array::from(vec![42, 99])),
4197 ],
4198 )
4199 .unwrap();
4200 let stream = RecordBatchStreamAdapter::new(
4201 partial_schema,
4202 futures::stream::iter(vec![Ok(partial)]),
4203 );
4204
4205 let UncommittedMergeInsert { transaction, .. } =
4206 MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
4207 .unwrap()
4208 .when_matched(WhenMatched::UpdateAll)
4209 .when_not_matched(WhenNotMatched::InsertAll)
4210 .try_build()
4211 .unwrap()
4212 .execute_uncommitted(Box::pin(stream) as SendableRecordBatchStream)
4213 .await
4214 .unwrap();
4215
4216 let committed = CommitBuilder::new(dataset.clone())
4220 .execute(transaction)
4221 .await
4222 .unwrap();
4223 let tx_path = committed
4224 .manifest()
4225 .transaction_file
4226 .clone()
4227 .expect("transaction file must be written");
4228 let tx_read =
4229 read_transaction_file(dataset.object_store.as_ref(), &dataset.base, &tx_path)
4230 .await
4231 .unwrap();
4232 match &tx_read.operation {
4233 Operation::Update {
4234 inserted_rows_filter,
4235 ..
4236 } => {
4237 let filter = inserted_rows_filter
4238 .as_ref()
4239 .expect("partial-schema upsert on a PK must emit a bloom filter");
4240 assert_eq!(filter.field_ids.len(), 1);
4242 }
4243 other => panic!("expected Operation::Update, got: {:?}", other),
4244 }
4245 }
4246 }
4247
/// Runs `concurrency` merge inserts in parallel against a throttled in-memory
/// store. Each job upserts a different existing id, so the commits contend
/// with each other.
///
/// Cases:
/// * `all_success`: an effectively unbounded retry timeout. Every job must
///   succeed within 10 attempts, and every `value` must end up as 1.
/// * `timeout`: a 200ms retry timeout. A job may succeed, or fail only with
///   `TooMuchWriteContention` whose message mentions `retry_timeout`.
#[cfg(not(windows))]
#[rstest::rstest]
#[case::all_success(Duration::from_secs(100_000))]
#[case::timeout(Duration::from_millis(200))]
#[tokio::test]
async fn test_merge_insert_concurrency(#[case] timeout: Duration) {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, false),
    ]));
    // Number of concurrent jobs; also the number of initial rows.
    let concurrency = 10;
    // Initial rows: ids 0..concurrency, all with value 0.
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from_iter_values(0..concurrency)),
            Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(
                0,
                concurrency as usize,
            ))),
        ],
    )
    .unwrap();

    // Add a fixed delay to list/get/put calls on the object store.
    // NOTE(review): presumably this widens the window in which concurrent commits
    // collide; confirm.
    let throttled = Arc::new(ThrottledStoreWrapper {
        config: ThrottleConfig {
            wait_list_per_call: Duration::from_millis(20),
            wait_get_per_call: Duration::from_millis(20),
            wait_put_per_call: Duration::from_millis(20),
            ..Default::default()
        },
    });
    // One session shared by the initial writer and every job.
    let session = Arc::new(Session::default());

    let mut dataset = InsertBuilder::new("memory://")
        .with_params(&WriteParams {
            store_params: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .execute(vec![initial_data])
        .await
        .unwrap();

    // Every job waits on this barrier so that all of them start executing together.
    let barrier = Arc::new(Barrier::new(concurrency as usize));
    let mut handles = Vec::new();
    for i in 0..concurrency {
        let session_ref = session.clone();
        let schema_ref = schema.clone();
        let barrier_ref = barrier.clone();
        let throttled_ref = throttled.clone();
        let handle = tokio::task::spawn(async move {
            // Each task opens its own dataset handle through the same throttled
            // store and session.
            let dataset = DatasetBuilder::from_uri("memory://")
                .with_read_params(ReadParams {
                    store_options: Some(ObjectStoreParams {
                        object_store_wrapper: Some(throttled_ref.clone()),
                        ..Default::default()
                    }),
                    session: Some(session_ref.clone()),
                    ..Default::default()
                })
                .load()
                .await
                .unwrap();
            let dataset = Arc::new(dataset);

            // Single-row upsert: id `i` -> value 1.
            let new_data = RecordBatch::try_new(
                schema_ref.clone(),
                vec![
                    Arc::new(UInt32Array::from(vec![i])),
                    Arc::new(UInt32Array::from(vec![1])),
                ],
            )
            .unwrap();
            let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema_ref.clone()));

            let job = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched(WhenNotMatched::InsertAll)
                .conflict_retries(100)
                .retry_timeout(timeout)
                .try_build()
                .unwrap();
            barrier_ref.wait().await;

            // Report how many attempts the job needed (or its error).
            job.execute_reader(source)
                .await
                .map(|(_ds, stats)| stats.num_attempts)
        });
        handles.push(handle);
    }

    // The outer `unwrap` fails only if a task panicked or was cancelled. Each
    // element is that job's own `Result` of attempt counts.
    let results = try_join_all(handles).await.unwrap();

    for attempts in results.iter() {
        match attempts {
            Ok(attempts) => {
                assert!(*attempts <= 10, "Attempt count should be <= 10");
            }
            Err(err) => {
                // The only acceptable failure is running out of retry time.
                assert!(
                    matches!(err, Error::TooMuchWriteContention { message, .. } if message.contains("failed on retry_timeout")),
                    "Expected TooMuchWriteContention error, got: {:?}",
                    err
                );
            }
        }
    }

    // Only the effectively unbounded case guarantees that every upsert landed.
    if timeout.as_secs() > 10 {
        dataset.checkout_latest().await.unwrap();
        let batches = dataset.scan().try_into_batch().await.unwrap();

        let values = batches["value"].as_primitive::<UInt32Type>();
        assert!(
            values.values().iter().all(|&v| v == 1),
            "All values should be 1 after merge insert. Got: {:?}",
            values
        );
    }
}
4386
4387 #[tokio::test]
4388 async fn test_merge_insert_large_concurrent() {
4389 let schema = Arc::new(Schema::new(vec![
4390 Field::new("id", DataType::UInt32, false),
4391 Field::new("value", DataType::UInt32, false),
4392 ]));
4393 let num_rows = 10;
4394 let initial_data = RecordBatch::try_new(
4395 schema.clone(),
4396 vec![
4397 Arc::new(UInt32Array::from_iter_values(0..num_rows)),
4398 Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(
4399 0,
4400 num_rows as usize,
4401 ))),
4402 ],
4403 )
4404 .unwrap();
4405
4406 let throttled = Arc::new(ThrottledStoreWrapper {
4408 config: ThrottleConfig {
4409 wait_list_per_call: Duration::from_millis(10),
4410 wait_get_per_call: Duration::from_millis(10),
4411 ..Default::default()
4412 },
4413 });
4414 let session = Arc::new(Session::default());
4415
4416 let dataset = InsertBuilder::new("memory://")
4417 .with_params(&WriteParams {
4418 store_params: Some(ObjectStoreParams {
4419 object_store_wrapper: Some(throttled.clone()),
4420 ..Default::default()
4421 }),
4422 session: Some(session.clone()),
4423 ..Default::default()
4424 })
4425 .execute(vec![initial_data])
4426 .await
4427 .unwrap();
4428 let dataset = Arc::new(dataset);
4429
4430 let new_data1 = RecordBatch::try_new(
4432 schema.clone(),
4433 vec![
4434 Arc::new(UInt32Array::from(vec![1])),
4435 Arc::new(UInt32Array::from(vec![1])),
4436 ],
4437 )
4438 .unwrap();
4439 let UncommittedMergeInsert {
4440 transaction: transaction1,
4441 ..
4442 } = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
4443 .unwrap()
4444 .when_matched(WhenMatched::UpdateAll)
4445 .when_not_matched(WhenNotMatched::InsertAll)
4446 .try_build()
4447 .unwrap()
4448 .execute_uncommitted(RecordBatchIterator::new(
4449 vec![Ok(new_data1)],
4450 schema.clone(),
4451 ))
4452 .await
4453 .unwrap();
4454
4455 let new_data2 = RecordBatch::try_new(
4457 schema.clone(),
4458 vec![
4459 Arc::new(UInt32Array::from_iter_values(0..1000)),
4460 Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(2, 1000))),
4461 ],
4462 )
4463 .unwrap();
4464 let notify = Arc::new(Notify::new());
4465 let source = RecordBatchIterator::new(
4466 (0..10)
4467 .map(|i| {
4468 let batch = new_data2.slice(i * 100, 100);
4469 if i == 9 {
4470 notify.notify_one();
4471 }
4472 Ok(batch)
4473 })
4474 .collect::<Vec<_>>(),
4475 schema.clone(),
4476 );
4477 let dataset2 = DatasetBuilder::from_uri("memory://")
4478 .with_read_params(ReadParams {
4479 store_options: Some(ObjectStoreParams {
4480 object_store_wrapper: Some(throttled.clone()),
4481 ..Default::default()
4482 }),
4483 session: Some(session.clone()),
4484 ..Default::default()
4485 })
4486 .load()
4487 .await
4488 .unwrap();
4489 let job = MergeInsertBuilder::try_new(Arc::new(dataset2), vec!["id".to_string()])
4490 .unwrap()
4491 .when_matched(WhenMatched::UpdateAll)
4492 .when_not_matched(WhenNotMatched::InsertAll)
4493 .try_build()
4494 .unwrap()
4495 .execute_reader(source);
4496 let task = tokio::task::spawn(job);
4497
4498 notify.notified().await;
4502 let mut dataset = CommitBuilder::new(dataset)
4503 .execute(transaction1)
4504 .await
4505 .unwrap();
4506
4507 task.await.unwrap().unwrap();
4508 dataset.checkout_latest().await.unwrap();
4509
4510 let batches = dataset.scan().try_into_batch().await.unwrap();
4511 let values = batches["value"].as_primitive::<UInt32Type>();
4512 assert!(
4513 values.values().iter().all(|&v| v == 2),
4514 "All values should be 1 after merge insert. Got: {:?}",
4515 values
4516 );
4517 }
4518
4519 #[tokio::test]
4520 async fn test_merge_insert_updates_indices() {
4521 let test_dataset = async || {
4522 let mut dataset = lance_datagen::gen_batch()
4523 .col("id", array::step::<UInt32Type>())
4524 .col("value", array::step::<UInt32Type>())
4525 .col("other_value", array::step::<UInt32Type>())
4526 .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(20))
4527 .await
4528 .unwrap();
4529
4530 dataset
4531 .create_index(
4532 &["id"],
4533 IndexType::BTree,
4534 None,
4535 &ScalarIndexParams::default(),
4536 false,
4537 )
4538 .await
4539 .unwrap();
4540 dataset
4541 .create_index(
4542 &["value"],
4543 IndexType::BTree,
4544 None,
4545 &ScalarIndexParams::default(),
4546 false,
4547 )
4548 .await
4549 .unwrap();
4550 dataset
4551 .create_index(
4552 &["other_value"],
4553 IndexType::BTree,
4554 None,
4555 &ScalarIndexParams::default(),
4556 false,
4557 )
4558 .await
4559 .unwrap();
4560 Arc::new(dataset)
4561 };
4562
4563 let check_indices = async |dataset: &Dataset, id_frags: &[u32], value_frags: &[u32]| {
4564 let id_index = dataset
4565 .load_scalar_index(IndexCriteria::default().with_name("id_idx"))
4566 .await
4567 .unwrap();
4568
4569 if id_frags.is_empty() {
4570 assert!(id_index.is_none());
4571 } else {
4572 let id_index = id_index.unwrap();
4573 let id_frags_bitmap = RoaringBitmap::from_iter(id_frags.iter().copied());
4574 let effective_bitmap = id_index
4576 .effective_fragment_bitmap(&dataset.fragment_bitmap)
4577 .unwrap();
4578 assert_eq!(effective_bitmap, id_frags_bitmap);
4579 }
4580
4581 let value_index = dataset
4582 .load_scalar_index(IndexCriteria::default().with_name("value_idx"))
4583 .await
4584 .unwrap();
4585
4586 if value_frags.is_empty() {
4587 assert!(value_index.is_none());
4588 } else {
4589 let value_index = value_index.unwrap();
4590 let value_frags_bitmap = RoaringBitmap::from_iter(value_frags.iter().copied());
4591 let effective_bitmap = value_index
4593 .effective_fragment_bitmap(&dataset.fragment_bitmap)
4594 .unwrap();
4595 assert_eq!(effective_bitmap, value_frags_bitmap);
4596 }
4597
4598 let other_value_index = dataset
4599 .load_scalar_index(IndexCriteria::default().with_name("other_value_idx"))
4600 .await
4601 .unwrap()
4602 .unwrap();
4603
4604 let effective_bitmap = other_value_index
4607 .effective_fragment_bitmap(&dataset.fragment_bitmap)
4608 .unwrap();
4609
4610 let index_bitmap = other_value_index.fragment_bitmap.as_ref().unwrap();
4615 let expected_bitmap = index_bitmap & dataset.fragment_bitmap.as_ref();
4616 assert_eq!(
4617 effective_bitmap, expected_bitmap,
4618 "other_value index effective bitmap should be intersection. index_bitmap: {:?}, dataset_fragments: {:?}, effective_bitmap: {:?}",
4619 index_bitmap, dataset.fragment_bitmap, effective_bitmap
4620 );
4621 };
4622
4623 let dataset = test_dataset().await;
4624
4625 check_indices(&dataset, &[0, 1, 2, 3], &[0, 1, 2, 3]).await;
4627
4628 let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
4631 .unwrap()
4632 .when_matched(WhenMatched::UpdateAll)
4633 .when_not_matched(WhenNotMatched::InsertAll)
4634 .when_not_matched(WhenNotMatched::InsertAll)
4635 .try_build()
4636 .unwrap();
4637
4638 let (dataset, _) = merge_insert
4639 .execute_reader(
4640 lance_datagen::gen_batch()
4641 .col("id", array::step_custom::<UInt32Type>(50, 1))
4642 .col("value", array::step_custom::<UInt32Type>(50, 1))
4643 .col("other_value", array::step_custom::<UInt32Type>(50, 1))
4644 .into_df_stream(RowCount::from(40), BatchCount::from(1)),
4645 )
4646 .await
4647 .unwrap();
4648
4649 check_indices(&dataset, &[0, 1, 2], &[0, 1, 2]).await;
4651
4652 let dataset = test_dataset().await;
4654
4655 let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
4658 .unwrap()
4659 .when_matched(WhenMatched::UpdateAll)
4660 .when_not_matched(WhenNotMatched::InsertAll)
4661 .when_not_matched(WhenNotMatched::InsertAll)
4662 .try_build()
4663 .unwrap();
4664
4665 let (dataset, _) = merge_insert
4666 .execute_reader(
4667 lance_datagen::gen_batch()
4668 .col("id", array::step_custom::<UInt32Type>(50, 1))
4669 .col("value", array::step_custom::<UInt32Type>(50, 1))
4670 .into_df_stream(RowCount::from(40), BatchCount::from(1)),
4671 )
4672 .await
4673 .unwrap();
4674
4675 check_indices(&dataset, &[0, 1], &[0, 1]).await;
4682
4683 let dataset = test_dataset().await;
4686
4687 let merge_insert = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
4690 .unwrap()
4691 .when_matched(WhenMatched::UpdateAll)
4692 .when_not_matched(WhenNotMatched::InsertAll)
4693 .when_not_matched(WhenNotMatched::InsertAll)
4694 .try_build()
4695 .unwrap();
4696
4697 let (dataset, _) = merge_insert
4698 .execute_reader(
4699 lance_datagen::gen_batch()
4700 .col("id", array::step_custom::<UInt32Type>(10, 1))
4701 .col("value", array::step_custom::<UInt32Type>(10, 1))
4702 .into_df_stream(RowCount::from(80), BatchCount::from(1)),
4703 )
4704 .await
4705 .unwrap();
4706
4707 check_indices(&dataset, &[], &[]).await;
4708 }
4709
/// Two concurrent upserts each update one of the two initial rows (ids 0 and 1),
/// so together they touch every row of the initial write.
///
/// After both commit, no fragment may be left with zero rows, and every `value`
/// must be 1.
#[tokio::test]
async fn test_upsert_concurrent_full_frag() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("value", DataType::UInt32, false),
    ]));
    // Initial data: ids 0 and 1, both with value 0.
    let initial_data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt32Array::from(vec![0, 1])),
            Arc::new(UInt32Array::from(vec![0, 0])),
        ],
    )
    .unwrap();

    // Add a small fixed delay to list/get/put calls on the object store.
    let throttled = Arc::new(ThrottledStoreWrapper {
        config: ThrottleConfig {
            wait_list_per_call: Duration::from_millis(5),
            wait_get_per_call: Duration::from_millis(5),
            wait_put_per_call: Duration::from_millis(5),
            ..Default::default()
        },
    });
    let session = Arc::new(Session::default());

    let mut dataset = InsertBuilder::new("memory://")
        .with_params(&WriteParams {
            store_params: Some(ObjectStoreParams {
                object_store_wrapper: Some(throttled.clone()),
                ..Default::default()
            }),
            session: Some(session.clone()),
            ..Default::default()
        })
        .execute(vec![initial_data])
        .await
        .unwrap();

    // Both jobs wait on the barrier so that they run at the same time.
    let barrier = Arc::new(Barrier::new(2));
    let mut handles = Vec::new();
    for i in 0..2 {
        // Job `i` upserts id `i` -> value 1.
        let new_data = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt32Array::from(vec![i])),
                Arc::new(UInt32Array::from(vec![1])),
            ],
        )
        .unwrap();
        let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema.clone()));

        // Each job works from its own clone of the initial dataset version.
        let dataset_ref = Arc::new(dataset.clone());
        let barrier = barrier.clone();
        let handle = tokio::spawn(async move {
            barrier.wait().await;
            MergeInsertBuilder::try_new(dataset_ref, vec!["id".to_string()])
                .unwrap()
                .when_matched(WhenMatched::UpdateAll)
                .when_not_matched(WhenNotMatched::InsertAll)
                .try_build()
                .unwrap()
                .execute_reader(source)
                .await
                .unwrap();
        });
        handles.push(handle);
    }
    try_join_all(handles).await.unwrap();

    dataset.checkout_latest().await.unwrap();
    // Once every row of a fragment has been rewritten, the fragment must be
    // dropped rather than kept with zero rows.
    assert!(
        dataset
            .get_fragments()
            .iter()
            .all(|f| f.metadata().num_rows().unwrap() > 0),
        "No fragments should have zero rows after upsert"
    );

    let batches = dataset.scan().try_into_batch().await.unwrap();
    let values = batches["value"].as_primitive::<UInt32Type>();
    assert!(
        values.values().iter().all(|&v| v == 1),
        "All values should be 1 after merge insert. Got: {:?}",
        values
    );
}
4799
4800 #[tokio::test]
4801 async fn test_plan_upsert() {
4802 let data = lance_datagen::gen_batch()
4803 .with_seed(Seed::from(1))
4804 .col("value", array::step::<UInt32Type>())
4805 .col("key", array::rand_pseudo_uuid_hex());
4806 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
4807 let _schema = data.schema();
4808
4809 let ds = Dataset::write(data, "memory://", None).await.unwrap();
4811
4812 let merge_insert_job =
4814 crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
4815 .unwrap()
4816 .when_matched(crate::dataset::WhenMatched::UpdateAll)
4817 .try_build()
4818 .unwrap();
4819
4820 let new_data = lance_datagen::gen_batch()
4822 .with_seed(Seed::from(2))
4823 .col("value", array::step::<UInt32Type>())
4824 .col("key", array::rand_pseudo_uuid_hex());
4825 let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
4826 let new_data_stream = reader_to_stream(Box::new(new_data));
4827
4828 let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
4829
4830 assert_plan_node_equals(
4839 plan,
4840 "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep
4841 CoalescePartitionsExec
4842 ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
4843 HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
4844 LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \
4845 row_id=true, row_addr=true, full_filter=--, refine_filter=--
4846 RepartitionExec: partitioning=RoundRobinBatch(...), input_partitions=1
4847 ProjectionExec: expr=[value@0 as value, key@1 as key, true as __merge_source_sentinel]
4848 StreamingTableExec: partition_sizes=1, projection=[value, key]"
4849 ).await.unwrap();
4850 }
4851
4852 #[tokio::test]
4853 async fn test_fast_path_update_only() {
4854 let data = lance_datagen::gen_batch()
4855 .with_seed(Seed::from(1))
4856 .col("value", array::step::<UInt32Type>())
4857 .col("key", array::rand_pseudo_uuid_hex());
4858 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
4859
4860 let ds = Dataset::write(data, "memory://", None).await.unwrap();
4862
4863 let merge_insert_job =
4865 crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
4866 .unwrap()
4867 .when_matched(crate::dataset::WhenMatched::UpdateAll)
4868 .when_not_matched(crate::dataset::WhenNotMatched::DoNothing)
4869 .try_build()
4870 .unwrap();
4871
4872 let new_data = lance_datagen::gen_batch()
4874 .with_seed(Seed::from(2))
4875 .col("value", array::step::<UInt32Type>())
4876 .col("key", array::rand_pseudo_uuid_hex());
4877 let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
4878 let new_data_stream = reader_to_stream(Box::new(new_data));
4879
4880 let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
4882
4883 assert_plan_node_equals(
4887 plan,
4888 "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep
4889 CoalescePartitionsExec
4890 ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
4891 HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
4892 LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
4893 RepartitionExec...
4894 ProjectionExec: expr=[value@0 as value, key@1 as key, true as __merge_source_sentinel]
4895 StreamingTableExec: partition_sizes=1, projection=[value, key]"
4896 ).await.unwrap();
4897 }
4898
4899 #[tokio::test]
4900 async fn test_fast_path_conditional_update() {
4901 let data = lance_datagen::gen_batch()
4902 .with_seed(Seed::from(1))
4903 .col("value", array::step::<UInt32Type>())
4904 .col("key", array::rand_pseudo_uuid_hex());
4905 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
4906
4907 let ds = Dataset::write(data, "memory://", None).await.unwrap();
4909
4910 let merge_insert_job = crate::dataset::MergeInsertBuilder::try_new(
4912 Arc::new(ds.clone()),
4913 vec!["key".to_string()],
4914 )
4915 .unwrap()
4916 .when_matched(crate::dataset::WhenMatched::update_if(&ds, "source.value > 20").unwrap())
4917 .when_not_matched(crate::dataset::WhenNotMatched::DoNothing)
4918 .try_build()
4919 .unwrap();
4920
4921 let new_data = lance_datagen::gen_batch()
4923 .with_seed(Seed::from(2))
4924 .col("value", array::step::<UInt32Type>())
4925 .col("key", array::rand_pseudo_uuid_hex());
4926 let new_data_reader = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
4927 let new_data_stream = reader_to_stream(Box::new(new_data_reader));
4928
4929 let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
4930
4931 assert_plan_node_equals(
4934 plan,
4935 "MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep
4936 CoalescePartitionsExec
4937 ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
4938 HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
4939 LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
4940 RepartitionExec...
4941 ProjectionExec: expr=[value@0 as value, key@1 as key, true as __merge_source_sentinel]
4942 StreamingTableExec: partition_sizes=1, projection=[value, key]"
4943 ).await.unwrap();
4944 }
4945
4946 #[tokio::test]
4953 async fn test_fast_path_find_or_create() {
4954 let data = lance_datagen::gen_batch()
4955 .with_seed(Seed::from(1))
4956 .col("value", array::step::<UInt32Type>())
4957 .col("key", array::rand_pseudo_uuid_hex());
4958 let data = data.into_reader_rows(RowCount::from(1024), BatchCount::from(32));
4959
4960 let ds = Dataset::write(data, "memory://", None).await.unwrap();
4962
4963 let merge_insert_job =
4966 crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
4967 .unwrap()
4968 .try_build()
4969 .unwrap();
4970
4971 let new_data = lance_datagen::gen_batch()
4973 .with_seed(Seed::from(2))
4974 .col("value", array::step::<UInt32Type>())
4975 .col("key", array::rand_pseudo_uuid_hex());
4976 let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16));
4977 let new_data_stream = reader_to_stream(Box::new(new_data));
4978
4979 let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap();
4982
4983 assert_plan_node_equals(
4988 plan,
4989 "MergeInsert: on=[key], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep
4990 CoalescePartitionsExec
4991 ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action]
4992 HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
4993 LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
4994 RepartitionExec...
4995 ProjectionExec: expr=[value@0 as value, key@1 as key, true as __merge_source_sentinel]
4996 StreamingTableExec: partition_sizes=1, projection=[value, key]"
4997 )
4998 .await
4999 .unwrap();
5000 }
5001
5002 #[tokio::test]
5003 async fn test_skip_auto_cleanup() {
5004 let tmpdir = TempStrDir::default();
5005 let dataset_uri = format!("{}/{}", tmpdir, "test_dataset");
5006
5007 let data = lance_datagen::gen_batch()
5009 .with_seed(Seed::from(1))
5010 .col("id", array::step::<UInt32Type>())
5011 .into_reader_rows(RowCount::from(100), BatchCount::from(1));
5012
5013 let mut auto_cleanup_params = HashMap::new();
5014 auto_cleanup_params.insert("lance.auto_cleanup.interval".to_string(), "1".to_string());
5015 auto_cleanup_params.insert(
5016 "lance.auto_cleanup.older_than".to_string(),
5017 "0ms".to_string(),
5018 );
5019
5020 let write_params = WriteParams {
5021 mode: WriteMode::Create,
5022 auto_cleanup: Some(crate::dataset::AutoCleanupParams {
5023 interval: 1,
5024 older_than: chrono::TimeDelta::try_milliseconds(0).unwrap(),
5025 }),
5026 ..Default::default()
5027 };
5028
5029 MockClock::set_system_time(std::time::Duration::from_secs(1));
5031
5032 let dataset = Dataset::write(data, &dataset_uri, Some(write_params))
5033 .await
5034 .unwrap();
5035 assert_eq!(dataset.version().version, 1);
5036
5037 MockClock::set_system_time(std::time::Duration::from_secs(2));
5039
5040 let new_data = lance_datagen::gen_batch()
5042 .with_seed(Seed::from(2))
5043 .col("id", array::step::<UInt32Type>())
5044 .into_df_stream(RowCount::from(50), BatchCount::from(1));
5045
5046 let (dataset2, _) = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
5047 .unwrap()
5048 .when_matched(WhenMatched::UpdateAll)
5049 .when_not_matched(WhenNotMatched::InsertAll)
5050 .try_build()
5051 .unwrap()
5052 .execute(new_data)
5053 .await
5054 .unwrap();
5055
5056 assert_eq!(dataset2.version().version, 2);
5057
5058 MockClock::set_system_time(std::time::Duration::from_secs(3));
5060
5061 let new_data_extra = lance_datagen::gen_batch()
5063 .with_seed(Seed::from(4))
5064 .col("id", array::step::<UInt32Type>())
5065 .into_df_stream(RowCount::from(10), BatchCount::from(1));
5066
5067 let (dataset2_extra, _) =
5068 MergeInsertBuilder::try_new(dataset2.clone(), vec!["id".to_string()])
5069 .unwrap()
5070 .when_matched(WhenMatched::UpdateAll)
5071 .when_not_matched(WhenNotMatched::InsertAll)
5072 .try_build()
5073 .unwrap()
5074 .execute(new_data_extra)
5075 .await
5076 .unwrap();
5077
5078 assert_eq!(dataset2_extra.version().version, 3);
5079
5080 let ds_check1 = DatasetBuilder::from_uri(&dataset_uri).load().await.unwrap();
5082
5083 assert!(
5085 ds_check1.checkout_version(1).await.is_err(),
5086 "Version 1 should have been cleaned up"
5087 );
5088 assert!(
5090 ds_check1.checkout_version(2).await.is_ok(),
5091 "Version 2 should still exist"
5092 );
5093
5094 MockClock::set_system_time(std::time::Duration::from_secs(4));
5096
5097 let new_data2 = lance_datagen::gen_batch()
5099 .with_seed(Seed::from(3))
5100 .col("id", array::step::<UInt32Type>())
5101 .into_df_stream(RowCount::from(30), BatchCount::from(1));
5102
5103 let (dataset3, _) = MergeInsertBuilder::try_new(dataset2_extra, vec!["id".to_string()])
5104 .unwrap()
5105 .when_matched(WhenMatched::UpdateAll)
5106 .when_not_matched(WhenNotMatched::InsertAll)
5107 .skip_auto_cleanup(true) .try_build()
5109 .unwrap()
5110 .execute(new_data2)
5111 .await
5112 .unwrap();
5113
5114 assert_eq!(dataset3.version().version, 4);
5115
5116 let ds_check2 = DatasetBuilder::from_uri(&dataset_uri).load().await.unwrap();
5118
5119 assert!(
5121 ds_check2.checkout_version(2).await.is_ok(),
5122 "Version 2 should still exist because skip_auto_cleanup was enabled"
5123 );
5124 assert!(
5126 ds_check2.checkout_version(3).await.is_ok(),
5127 "Version 3 should still exist"
5128 );
5129 }
5130
5131 #[tokio::test]
5132 async fn test_transaction_inserted_rows_filter_roundtrip() {
5133 let schema = Arc::new(Schema::new(vec![
5135 Field::new("id", DataType::UInt32, false).with_metadata(
5136 vec![(
5137 "lance-schema:unenforced-primary-key".to_string(),
5138 "true".to_string(),
5139 )]
5140 .into_iter()
5141 .collect(),
5142 ),
5143 Field::new("value", DataType::UInt32, false),
5144 ]));
5145 let initial = RecordBatch::try_new(
5146 schema.clone(),
5147 vec![
5148 Arc::new(UInt32Array::from(vec![0, 1, 2])),
5149 Arc::new(UInt32Array::from(vec![0, 0, 0])),
5150 ],
5151 )
5152 .unwrap();
5153 let dataset = InsertBuilder::new("memory://")
5154 .execute(vec![initial])
5155 .await
5156 .unwrap();
5157 let dataset = Arc::new(dataset);
5158
5159 let new_batch = RecordBatch::try_new(
5161 schema.clone(),
5162 vec![
5163 Arc::new(UInt32Array::from(vec![1, 3])),
5164 Arc::new(UInt32Array::from(vec![2, 2])),
5165 ],
5166 )
5167 .unwrap();
5168 let stream = RecordBatchStreamAdapter::new(
5169 schema.clone(),
5170 futures::stream::iter(vec![Ok(new_batch)]),
5171 );
5172
5173 let UncommittedMergeInsert { transaction, .. } =
5174 MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5175 .unwrap()
5176 .when_matched(WhenMatched::UpdateAll)
5177 .when_not_matched(WhenNotMatched::InsertAll)
5178 .try_build()
5179 .unwrap()
5180 .execute_uncommitted(Box::pin(stream) as SendableRecordBatchStream)
5181 .await
5182 .unwrap();
5183
5184 let committed = CommitBuilder::new(dataset.clone())
5186 .execute(transaction)
5187 .await
5188 .unwrap();
5189 let tx_path = committed.manifest().transaction_file.clone().unwrap();
5190 let tx_read = read_transaction_file(dataset.object_store.as_ref(), &dataset.base, &tx_path)
5191 .await
5192 .unwrap();
5193 if let Operation::Update {
5195 inserted_rows_filter,
5196 ..
5197 } = &tx_read.operation
5198 {
5199 assert!(inserted_rows_filter.is_some());
5200 let filter = inserted_rows_filter.as_ref().unwrap();
5201 assert_eq!(filter.field_ids.len(), 1);
5203 } else {
5204 panic!("Expected Operation::Update");
5205 }
5206 }
5207
5208 #[tokio::test]
5212 async fn test_inserted_rows_filter_bloom_conflict_detection_concurrent() {
5213 let schema = Arc::new(Schema::new(vec![
5215 Field::new("id", DataType::UInt32, false).with_metadata(
5216 vec![(
5217 "lance-schema:unenforced-primary-key".to_string(),
5218 "true".to_string(),
5219 )]
5220 .into_iter()
5221 .collect(),
5222 ),
5223 Field::new("value", DataType::UInt32, false),
5224 ]));
5225 let initial = RecordBatch::try_new(
5226 schema.clone(),
5227 vec![
5228 Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
5229 Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
5230 ],
5231 )
5232 .unwrap();
5233
5234 let dataset = InsertBuilder::new("memory://")
5235 .execute(vec![initial])
5236 .await
5237 .unwrap();
5238 let dataset = Arc::new(dataset);
5239
5240 let batch1 = RecordBatch::try_new(
5242 schema.clone(),
5243 vec![
5244 Arc::new(UInt32Array::from(vec![2])),
5245 Arc::new(UInt32Array::from(vec![1])),
5246 ],
5247 )
5248 .unwrap();
5249 let batch2 = RecordBatch::try_new(
5250 schema.clone(),
5251 vec![
5252 Arc::new(UInt32Array::from(vec![2])),
5253 Arc::new(UInt32Array::from(vec![2])),
5254 ],
5255 )
5256 .unwrap();
5257
5258 let b2 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5260 .unwrap()
5261 .when_matched(WhenMatched::UpdateAll)
5262 .when_not_matched(WhenNotMatched::InsertAll)
5263 .conflict_retries(0)
5264 .try_build()
5265 .unwrap();
5266
5267 let s1 = RecordBatchStreamAdapter::new(
5269 schema.clone(),
5270 futures::stream::iter(vec![Ok(batch1.clone())]),
5271 );
5272 let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5273 .unwrap()
5274 .when_matched(WhenMatched::UpdateAll)
5275 .when_not_matched(WhenNotMatched::InsertAll)
5276 .try_build()
5277 .unwrap();
5278 let result1 = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
5279 assert!(result1.is_ok(), "First merge insert should succeed");
5280
5281 let s2 = RecordBatchStreamAdapter::new(
5283 schema.clone(),
5284 futures::stream::iter(vec![Ok(batch2.clone())]),
5285 );
5286 let result2 = b2.execute(Box::pin(s2) as SendableRecordBatchStream).await;
5287
5288 assert!(
5290 matches!(result2, Err(crate::Error::TooMuchWriteContention { .. })),
5291 "Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
5292 result2
5293 );
5294 }
5295
5296 #[tokio::test]
5300 async fn test_concurrent_insert_same_new_key() {
5301 let schema = Arc::new(Schema::new(vec![
5303 Field::new("id", DataType::UInt32, false).with_metadata(
5304 vec![(
5305 "lance-schema:unenforced-primary-key".to_string(),
5306 "true".to_string(),
5307 )]
5308 .into_iter()
5309 .collect(),
5310 ),
5311 Field::new("value", DataType::UInt32, false),
5312 ]));
5313 let initial = RecordBatch::try_new(
5315 schema.clone(),
5316 vec![
5317 Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
5318 Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
5319 ],
5320 )
5321 .unwrap();
5322
5323 let dataset = InsertBuilder::new("memory://")
5324 .execute(vec![initial])
5325 .await
5326 .unwrap();
5327 let dataset = Arc::new(dataset);
5328
5329 let batch1 = RecordBatch::try_new(
5331 schema.clone(),
5332 vec![
5333 Arc::new(UInt32Array::from(vec![100])), Arc::new(UInt32Array::from(vec![1])),
5335 ],
5336 )
5337 .unwrap();
5338 let batch2 = RecordBatch::try_new(
5339 schema.clone(),
5340 vec![
5341 Arc::new(UInt32Array::from(vec![100])), Arc::new(UInt32Array::from(vec![2])),
5343 ],
5344 )
5345 .unwrap();
5346
5347 let b2 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5349 .unwrap()
5350 .when_matched(WhenMatched::UpdateAll)
5351 .when_not_matched(WhenNotMatched::InsertAll)
5352 .conflict_retries(0)
5353 .try_build()
5354 .unwrap();
5355
5356 let s1 = RecordBatchStreamAdapter::new(
5358 schema.clone(),
5359 futures::stream::iter(vec![Ok(batch1.clone())]),
5360 );
5361 let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5362 .unwrap()
5363 .when_matched(WhenMatched::UpdateAll)
5364 .when_not_matched(WhenNotMatched::InsertAll)
5365 .try_build()
5366 .unwrap();
5367 let result1 = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
5368 assert!(result1.is_ok(), "First merge insert should succeed");
5369
5370 let s2 = RecordBatchStreamAdapter::new(
5372 schema.clone(),
5373 futures::stream::iter(vec![Ok(batch2.clone())]),
5374 );
5375 let result2 = b2.execute(Box::pin(s2) as SendableRecordBatchStream).await;
5376
5377 assert!(
5379 matches!(result2, Err(crate::Error::TooMuchWriteContention { .. })),
5380 "Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
5381 result2
5382 );
5383 }
5384
5385 #[tokio::test]
5394 async fn test_concurrent_find_or_create_same_new_key() {
5395 let schema = Arc::new(Schema::new(vec![
5398 Field::new("id", DataType::UInt32, false).with_metadata(
5399 vec![(
5400 "lance-schema:unenforced-primary-key".to_string(),
5401 "true".to_string(),
5402 )]
5403 .into_iter()
5404 .collect(),
5405 ),
5406 Field::new("value", DataType::UInt32, false),
5407 ]));
5408 let initial = RecordBatch::try_new(
5410 schema.clone(),
5411 vec![
5412 Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
5413 Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
5414 ],
5415 )
5416 .unwrap();
5417
5418 let dataset = InsertBuilder::new("memory://")
5419 .execute(vec![initial])
5420 .await
5421 .unwrap();
5422 let dataset = Arc::new(dataset);
5423
5424 let batch1 = RecordBatch::try_new(
5426 schema.clone(),
5427 vec![
5428 Arc::new(UInt32Array::from(vec![100])),
5429 Arc::new(UInt32Array::from(vec![1])),
5430 ],
5431 )
5432 .unwrap();
5433 let batch2 = RecordBatch::try_new(
5434 schema.clone(),
5435 vec![
5436 Arc::new(UInt32Array::from(vec![100])),
5437 Arc::new(UInt32Array::from(vec![2])),
5438 ],
5439 )
5440 .unwrap();
5441
5442 let b2 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5446 .unwrap()
5447 .when_matched(WhenMatched::DoNothing)
5448 .when_not_matched(WhenNotMatched::InsertAll)
5449 .conflict_retries(0)
5450 .try_build()
5451 .unwrap();
5452
5453 let s1 = RecordBatchStreamAdapter::new(
5455 schema.clone(),
5456 futures::stream::iter(vec![Ok(batch1.clone())]),
5457 );
5458 let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5459 .unwrap()
5460 .when_matched(WhenMatched::DoNothing)
5461 .when_not_matched(WhenNotMatched::InsertAll)
5462 .try_build()
5463 .unwrap();
5464 let result1 = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
5465 assert!(result1.is_ok(), "First find-or-create should succeed");
5466
5467 let s2 = RecordBatchStreamAdapter::new(
5469 schema.clone(),
5470 futures::stream::iter(vec![Ok(batch2.clone())]),
5471 );
5472 let result2 = b2.execute(Box::pin(s2) as SendableRecordBatchStream).await;
5473
5474 assert!(
5475 matches!(result2, Err(crate::Error::TooMuchWriteContention { .. })),
5476 "Expected TooMuchWriteContention (bloom-filter conflict) for find-or-create, got: {:?}",
5477 result2
5478 );
5479 }
5480
5481 #[test]
5482 fn test_concurrent_insert_different_new_list_key() {
5483 let tags_field = Field::new(
5485 "tags",
5486 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
5487 false,
5488 );
5489 let schema = Arc::new(Schema::new(vec![tags_field]));
5490
5491 let mut builder = ListBuilder::new(StringBuilder::new());
5493 builder.append_value(["a", "b"].iter().copied().map(Some));
5494 let tags_array1 = builder.finish();
5495 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(tags_array1)]).unwrap();
5496
5497 let mut builder = ListBuilder::new(StringBuilder::new());
5498 builder.append_value(["c", "d"].iter().copied().map(Some));
5499 let tags_array2 = builder.finish();
5500 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(tags_array2)]).unwrap();
5501
5502 let field_ids = vec![0_i32];
5504 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5505 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5506
5507 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
5508 .expect("first batch should produce key");
5509 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
5510 .expect("second batch should produce key");
5511
5512 builder1.insert(key1).unwrap();
5513 builder2.insert(key2).unwrap();
5514 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5515 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5516
5517 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5518 assert!(
5519 !has_intersection,
5520 "Expected bloom filters not intersect for different list(string) keys",
5521 );
5522 assert!(
5523 !might_be_fp,
5524 "Bloom filter intersection should be definitively not conflict",
5525 );
5526 }
5527
5528 #[test]
5529 fn test_concurrent_insert_same_new_list_key() {
5530 let tags_field = Field::new(
5532 "tags",
5533 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
5534 false,
5535 );
5536 let schema = Arc::new(Schema::new(vec![tags_field]));
5537
5538 let mut builder = ListBuilder::new(StringBuilder::new());
5540 builder.append_value(["a", "b"].iter().copied().map(Some));
5541 let tags_array1 = builder.finish();
5542 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(tags_array1)]).unwrap();
5543
5544 let mut builder = ListBuilder::new(StringBuilder::new());
5545 builder.append_value(["a", "b"].iter().copied().map(Some));
5546 let tags_array2 = builder.finish();
5547 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(tags_array2)]).unwrap();
5548
5549 let field_ids = vec![0_i32];
5551 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5552 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5553
5554 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
5555 .expect("first batch should produce key");
5556 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
5557 .expect("second batch should produce key");
5558
5559 builder1.insert(key1).unwrap();
5560 builder2.insert(key2).unwrap();
5561 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5562 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5563
5564 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5565 assert!(
5566 has_intersection,
5567 "Expected bloom filters to intersect for identical list(string) keys",
5568 );
5569 assert!(
5570 might_be_fp,
5571 "Bloom filter intersection should be treated as potential conflict",
5572 );
5573 }
5574
5575 #[test]
5576 fn test_concurrent_insert_same_new_nested_list_key() {
5577 let nested_tags = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
5579 let tags_field = Field::new("tags", nested_tags.data_type().clone(), false);
5580 let nested_tags2 = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
5581
5582 let schema = Arc::new(Schema::new(vec![tags_field]));
5583 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(nested_tags)]).unwrap();
5584 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(nested_tags2)]).unwrap();
5585
5586 let field_ids = vec![0_i32];
5588 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5589 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5590
5591 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
5592 .expect("first batch should produce key");
5593 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
5594 .expect("second batch should produce key");
5595
5596 builder1.insert(key1).unwrap();
5597 builder2.insert(key2).unwrap();
5598 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5599 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5600
5601 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5602 assert!(
5603 has_intersection,
5604 "Expected bloom filters to intersect for identical nested list(list(string)) keys",
5605 );
5606 assert!(
5607 might_be_fp,
5608 "Bloom filter intersection should be treated as potential conflict",
5609 );
5610 }
5611
5612 #[test]
5613 fn test_concurrent_insert_different_new_struct_key() {
5614 let user_field = Field::new(
5615 "user",
5616 DataType::Struct(
5617 vec![
5618 Field::new("first", DataType::Utf8, false),
5619 Field::new("last", DataType::Utf8, false),
5620 ]
5621 .into(),
5622 ),
5623 false,
5624 );
5625 let schema = Arc::new(Schema::new(vec![user_field]));
5626
5627 let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
5629 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
5630
5631 let struct_array2 = make_struct_array_first_last_name(vec!["bob"], vec!["jones"]);
5632 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
5633
5634 let field_ids = vec![0_i32];
5636 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5637 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5638
5639 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
5640 .expect("first batch should produce key");
5641 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
5642 .expect("second batch should produce key");
5643
5644 builder1.insert(key1).unwrap();
5645 builder2.insert(key2).unwrap();
5646 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5647 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5648
5649 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5650 assert!(
5651 !has_intersection,
5652 "Expected bloom filters not intersect for different struct keys",
5653 );
5654 assert!(
5655 !might_be_fp,
5656 "Bloom filter intersection should be definitively not conflict",
5657 );
5658 }
5659
5660 #[test]
5661 fn test_concurrent_insert_same_new_struct_key() {
5662 let user_field = Field::new(
5663 "user",
5664 DataType::Struct(
5665 vec![
5666 Field::new("first", DataType::Utf8, false),
5667 Field::new("last", DataType::Utf8, false),
5668 ]
5669 .into(),
5670 ),
5671 false,
5672 );
5673 let schema = Arc::new(Schema::new(vec![user_field]));
5674
5675 let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
5677 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
5678
5679 let struct_array2 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
5680 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
5681
5682 let field_ids = vec![0_i32];
5684 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5685 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5686
5687 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
5688 .expect("first batch should produce key");
5689 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
5690 .expect("second batch should produce key");
5691
5692 builder1.insert(key1).unwrap();
5693 builder2.insert(key2).unwrap();
5694 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5695 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5696
5697 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5698 assert!(
5699 has_intersection,
5700 "Expected bloom filters to intersect for identical struct keys",
5701 );
5702 assert!(
5703 might_be_fp,
5704 "Bloom filter intersection should be treated as potential conflict",
5705 );
5706 }
5707
5708 #[test]
5709 fn test_concurrent_insert_same_new_nested_struct_key() {
5710 let outer_struct = make_nested_struct_array_city_zip("seattle", 98101);
5712 let user_field = Field::new("user", outer_struct.data_type().clone(), false);
5713 let schema = Arc::new(Schema::new(vec![user_field]));
5714
5715 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(outer_struct)]).unwrap();
5716
5717 let outer_struct2 = make_nested_struct_array_city_zip("seattle", 98101);
5718 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(outer_struct2)]).unwrap();
5719
5720 let field_ids = vec![0_i32];
5722 let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
5723 let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
5724
5725 let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
5726 .expect("first batch should produce key");
5727 let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
5728 .expect("second batch should produce key");
5729
5730 builder1.insert(key1).unwrap();
5731 builder2.insert(key2).unwrap();
5732 let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
5733 let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
5734
5735 let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
5736 assert!(
5737 has_intersection,
5738 "Expected bloom filters to intersect for identical nested struct keys",
5739 );
5740 assert!(
5741 might_be_fp,
5742 "Bloom filter intersection should be treated as potential conflict",
5743 );
5744 }
5745
5746 #[tokio::test]
5748 async fn test_merge_insert_struct_key_upsert() {
5749 let user_field = Field::new(
5750 "user",
5751 DataType::Struct(
5752 vec![
5753 Field::new("first", DataType::Utf8, false),
5754 Field::new("last", DataType::Utf8, false),
5755 ]
5756 .into(),
5757 ),
5758 false,
5759 );
5760 let schema = Arc::new(Schema::new(vec![
5761 user_field,
5762 Field::new("value", DataType::UInt32, false),
5763 ]));
5764
5765 let user_array = make_struct_array_first_last_name(
5770 vec!["alice", "bob", "carla"],
5771 vec!["smith", "jones", "doe"],
5772 );
5773 let values = UInt32Array::from(vec![1, 1, 1]);
5774 let initial_batch =
5775 RecordBatch::try_new(schema.clone(), vec![Arc::new(user_array), Arc::new(values)])
5776 .unwrap();
5777
5778 let test_uri = "memory://test_merge_insert_struct_key.lance";
5779 let dataset = Dataset::write(
5780 RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()),
5781 test_uri,
5782 None,
5783 )
5784 .await
5785 .unwrap();
5786 let dataset = Arc::new(dataset);
5787
5788 let new_user_array =
5790 make_struct_array_first_last_name(vec!["alice", "david"], vec!["smith", "brown"]);
5791 let new_values = UInt32Array::from(vec![10, 2]);
5792 let new_batch = RecordBatch::try_new(
5793 schema.clone(),
5794 vec![Arc::new(new_user_array), Arc::new(new_values)],
5795 )
5796 .unwrap();
5797
5798 let reader = RecordBatchIterator::new([Ok(new_batch)], schema.clone());
5799 let (merged_ds, stats) = MergeInsertBuilder::try_new(dataset, vec!["user".to_string()])
5800 .unwrap()
5801 .when_matched(WhenMatched::UpdateAll)
5802 .when_not_matched(WhenNotMatched::InsertAll)
5803 .try_build()
5804 .unwrap()
5805 .execute(reader_to_stream(Box::new(reader)))
5806 .await
5807 .unwrap();
5808
5809 assert_eq!(stats.num_updated_rows, 1);
5810 assert_eq!(stats.num_inserted_rows, 1);
5811 assert_eq!(stats.num_deleted_rows, 0);
5812
5813 let result = merged_ds.scan().try_into_batch().await.unwrap();
5814 let user_col = result
5815 .column_by_name("user")
5816 .unwrap()
5817 .as_any()
5818 .downcast_ref::<StructArray>()
5819 .unwrap();
5820 let first = user_col
5821 .column(0)
5822 .as_any()
5823 .downcast_ref::<StringArray>()
5824 .unwrap();
5825 let last = user_col
5826 .column(1)
5827 .as_any()
5828 .downcast_ref::<StringArray>()
5829 .unwrap();
5830 let values = result
5831 .column_by_name("value")
5832 .unwrap()
5833 .as_primitive::<UInt32Type>();
5834
5835 let mut rows = Vec::new();
5836 for i in 0..result.num_rows() {
5837 rows.push((
5838 first.value(i).to_string(),
5839 last.value(i).to_string(),
5840 values.value(i),
5841 ));
5842 }
5843 rows.sort();
5844
5845 assert_eq!(
5846 rows,
5847 vec![
5848 ("alice".to_string(), "smith".to_string(), 10),
5849 ("bob".to_string(), "jones".to_string(), 1),
5850 ("carla".to_string(), "doe".to_string(), 1),
5851 ("david".to_string(), "brown".to_string(), 2),
5852 ],
5853 );
5854 }
5855
5856 fn make_struct_array_first_last_name(first: Vec<&str>, last: Vec<&str>) -> StructArray {
5857 let first = StringArray::from(first);
5858 let last = StringArray::from(last);
5859
5860 StructArray::from(vec![
5861 (
5862 Arc::new(Field::new("first", DataType::Utf8, false)),
5863 Arc::new(first) as Arc<dyn Array>,
5864 ),
5865 (
5866 Arc::new(Field::new("last", DataType::Utf8, false)),
5867 Arc::new(last) as Arc<dyn Array>,
5868 ),
5869 ])
5870 }
5871
5872 fn make_nested_struct_array_city_zip(city: &str, zip: i32) -> StructArray {
5873 let city = StringArray::from(vec![city]);
5874 let zip = Int32Array::from(vec![zip]);
5875
5876 let inner_struct = StructArray::from(vec![
5877 (
5878 Arc::new(Field::new("city", DataType::Utf8, false)),
5879 Arc::new(city) as Arc<dyn Array>,
5880 ),
5881 (
5882 Arc::new(Field::new("zip", DataType::Int32, false)),
5883 Arc::new(zip) as Arc<dyn Array>,
5884 ),
5885 ]);
5886
5887 StructArray::from(vec![(
5888 Arc::new(Field::new(
5889 "address",
5890 inner_struct.data_type().clone(),
5891 false,
5892 )),
5893 Arc::new(inner_struct) as Arc<dyn Array>,
5894 )])
5895 }
5896
5897 fn make_nested_array(inner_lists: &[&[&str]]) -> ListArray {
5898 let mut inner_builder = ListBuilder::new(StringBuilder::new());
5899 for inner in inner_lists {
5900 inner_builder.append_value(inner.iter().map(|s| Some(*s)));
5901 }
5902 let inner_list_array = inner_builder.finish();
5903
5904 let offsets = ScalarBuffer::<i32>::from(vec![0, inner_list_array.len() as i32]);
5905 let offsets = OffsetBuffer::new(offsets);
5906 ListArray::new(
5907 Arc::new(Field::new(
5908 "item",
5909 inner_list_array.data_type().clone(),
5910 inner_list_array.nulls().is_some(),
5911 )),
5912 offsets,
5913 Arc::new(inner_list_array),
5914 None,
5915 )
5916 }
5917
5918 #[tokio::test]
5922 async fn test_merge_insert_conflict_with_update_without_filter() {
5923 use crate::dataset::UpdateBuilder;
5924
5925 let schema = Arc::new(Schema::new(vec![
5927 Field::new("id", DataType::UInt32, false).with_metadata(
5928 vec![(
5929 "lance-schema:unenforced-primary-key".to_string(),
5930 "true".to_string(),
5931 )]
5932 .into_iter()
5933 .collect(),
5934 ),
5935 Field::new("value", DataType::UInt32, false),
5936 ]));
5937 let initial = RecordBatch::try_new(
5938 schema.clone(),
5939 vec![
5940 Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
5941 Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
5942 ],
5943 )
5944 .unwrap();
5945
5946 let dataset = InsertBuilder::new("memory://")
5947 .execute(vec![initial])
5948 .await
5949 .unwrap();
5950 let dataset = Arc::new(dataset);
5951
5952 let batch1 = RecordBatch::try_new(
5954 schema.clone(),
5955 vec![
5956 Arc::new(UInt32Array::from(vec![100])),
5957 Arc::new(UInt32Array::from(vec![1])),
5958 ],
5959 )
5960 .unwrap();
5961
5962 let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
5963 .unwrap()
5964 .when_matched(WhenMatched::UpdateAll)
5965 .when_not_matched(WhenNotMatched::InsertAll)
5966 .conflict_retries(0)
5967 .try_build()
5968 .unwrap();
5969
5970 let update_result = UpdateBuilder::new(dataset.clone())
5972 .update_where("id = 0")
5973 .unwrap()
5974 .set("value", "999")
5975 .unwrap()
5976 .build()
5977 .unwrap()
5978 .execute()
5979 .await;
5980 assert!(update_result.is_ok(), "Update should succeed");
5981
5982 let s1 = RecordBatchStreamAdapter::new(
5984 schema.clone(),
5985 futures::stream::iter(vec![Ok(batch1.clone())]),
5986 );
5987 let merge_result = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
5988
5989 assert!(
5992 matches!(
5993 merge_result,
5994 Err(crate::Error::TooMuchWriteContention { .. })
5995 ),
5996 "Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
5997 merge_result
5998 );
5999 }
6000
6001 #[tokio::test]
6005 async fn test_merge_insert_conflict_with_append() {
6006 let schema = Arc::new(Schema::new(vec![
6008 Field::new("id", DataType::UInt32, false).with_metadata(
6009 vec![(
6010 "lance-schema:unenforced-primary-key".to_string(),
6011 "true".to_string(),
6012 )]
6013 .into_iter()
6014 .collect(),
6015 ),
6016 Field::new("value", DataType::UInt32, false),
6017 ]));
6018 let initial = RecordBatch::try_new(
6019 schema.clone(),
6020 vec![
6021 Arc::new(UInt32Array::from(vec![0, 1, 2, 3])),
6022 Arc::new(UInt32Array::from(vec![0, 0, 0, 0])),
6023 ],
6024 )
6025 .unwrap();
6026
6027 let dataset = InsertBuilder::new("memory://")
6028 .execute(vec![initial])
6029 .await
6030 .unwrap();
6031 let dataset = Arc::new(dataset);
6032
6033 let batch1 = RecordBatch::try_new(
6035 schema.clone(),
6036 vec![
6037 Arc::new(UInt32Array::from(vec![100])),
6038 Arc::new(UInt32Array::from(vec![1])),
6039 ],
6040 )
6041 .unwrap();
6042
6043 let b1 = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()])
6044 .unwrap()
6045 .when_matched(WhenMatched::UpdateAll)
6046 .when_not_matched(WhenNotMatched::InsertAll)
6047 .conflict_retries(0)
6048 .try_build()
6049 .unwrap();
6050
6051 let append_batch = RecordBatch::try_new(
6053 schema.clone(),
6054 vec![
6055 Arc::new(UInt32Array::from(vec![50])),
6056 Arc::new(UInt32Array::from(vec![2])),
6057 ],
6058 )
6059 .unwrap();
6060 let append_result = InsertBuilder::new(dataset.clone())
6061 .with_params(&WriteParams {
6062 mode: WriteMode::Append,
6063 ..Default::default()
6064 })
6065 .execute(vec![append_batch])
6066 .await;
6067 assert!(append_result.is_ok(), "Append should succeed");
6068
6069 let s1 = RecordBatchStreamAdapter::new(
6071 schema.clone(),
6072 futures::stream::iter(vec![Ok(batch1.clone())]),
6073 );
6074 let merge_result = b1.execute(Box::pin(s1) as SendableRecordBatchStream).await;
6075
6076 assert!(
6079 matches!(
6080 merge_result,
6081 Err(crate::Error::TooMuchWriteContention { .. })
6082 ),
6083 "Expected TooMuchWriteContention (retryable conflict exhausted), got: {:?}",
6084 merge_result
6085 );
6086 }
6087
6088 #[tokio::test]
6089 async fn test_explain_plan() {
6090 let dataset = lance_datagen::gen_batch()
6092 .col("id", lance_datagen::array::step::<Int32Type>())
6093 .col("name", array::cycle_utf8_literals(&["a", "b", "c"]))
6094 .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(3))
6095 .await
6096 .unwrap();
6097
6098 let merge_insert_job =
6100 MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
6101 .unwrap()
6102 .when_matched(WhenMatched::UpdateAll)
6103 .when_not_matched(WhenNotMatched::InsertAll)
6104 .try_build()
6105 .unwrap();
6106
6107 let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
6109
6110 let expected_pattern = "\
6112MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep...
6113 CoalescePartitionsExec...
6114 HashJoinExec...
6115 LanceRead...
6116 StreamingTableExec: partition_sizes=1, projection=[id, name]";
6117 assert_string_matches(&plan, expected_pattern).unwrap();
6118
6119 let source_schema = arrow_schema::Schema::from(dataset.schema());
6121 let explicit_plan = merge_insert_job
6122 .explain_plan(Some(&source_schema), false)
6123 .await
6124 .unwrap();
6125 assert_eq!(plan, explicit_plan); let verbose_plan = merge_insert_job.explain_plan(None, true).await.unwrap();
6129 assert!(verbose_plan.contains("MergeInsert"));
6130 assert_string_matches(&verbose_plan, expected_pattern).unwrap();
6132 }
6133
6134 #[tokio::test]
6139 async fn test_explain_plan_find_or_create() {
6140 let dataset = lance_datagen::gen_batch()
6141 .col("id", lance_datagen::array::step::<Int32Type>())
6142 .col("name", array::cycle_utf8_literals(&["a", "b", "c"]))
6143 .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(3))
6144 .await
6145 .unwrap();
6146
6147 let merge_insert_job =
6149 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
6150 .unwrap()
6151 .try_build()
6152 .unwrap();
6153
6154 let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
6155
6156 let expected_pattern = "\
6157MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep...
6158 CoalescePartitionsExec...
6159 HashJoinExec...join_type=Right...
6160 LanceRead...
6161 StreamingTableExec: partition_sizes=1, projection=[id, name]";
6162 assert_string_matches(&plan, expected_pattern).unwrap();
6163 }
6164
6165 #[tokio::test]
6166 async fn test_explain_plan_full_schema_delete_by_source_with_fsl() {
6167 let schema = Arc::new(Schema::new(vec![
6168 Field::new("id", DataType::Int32, false),
6169 Field::new(
6170 "vec",
6171 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
6172 true,
6173 ),
6174 ]));
6175
6176 let dataset_batch = RecordBatch::try_new(
6177 schema.clone(),
6178 vec![
6179 Arc::new(Int32Array::from(vec![1, 2, 3])),
6180 Arc::new(
6181 FixedSizeListArray::try_new_from_values(
6182 Float32Array::from(vec![
6183 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
6184 ]),
6185 4,
6186 )
6187 .unwrap(),
6188 ),
6189 ],
6190 )
6191 .unwrap();
6192
6193 let dataset = Dataset::write(
6194 Box::new(RecordBatchIterator::new(
6195 [Ok(dataset_batch)],
6196 schema.clone(),
6197 )),
6198 "memory://test_explain_plan_full_schema_delete_by_source_with_fsl",
6199 None,
6200 )
6201 .await
6202 .unwrap();
6203
6204 let merge_insert_job =
6205 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
6206 .unwrap()
6207 .when_matched(WhenMatched::UpdateAll)
6208 .when_not_matched(WhenNotMatched::InsertAll)
6209 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
6210 .use_index(false)
6211 .try_build()
6212 .unwrap();
6213
6214 let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
6215 assert!(plan.contains("HashJoinExec"));
6216 assert!(plan.contains("join_type=Full"));
6217 assert!(plan.contains("projection=[_rowid"));
6218 assert!(
6219 plan.contains("LanceRead: uri=") && plan.contains("projection=[id]"),
6220 "target-side scan should prune the FSL payload from the join build side: {plan}"
6221 );
6222 assert!(
6223 !plan.contains("LanceRead: uri=test_explain_plan_full_schema_delete_by_source_with_fsl/data, projection=[id, vec]"),
6224 "target-side scan should not include the FSL payload in the join build side: {plan}"
6225 );
6226 }
6227
6228 #[tokio::test]
6229 async fn test_explain_plan_full_schema_delete_by_source_with_fsl_and_scalar_index() {
6230 let schema = Arc::new(Schema::new(vec![
6231 Field::new("id", DataType::Int32, false),
6232 Field::new(
6233 "vec",
6234 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
6235 true,
6236 ),
6237 ]));
6238
6239 let dataset_batch = RecordBatch::try_new(
6240 schema.clone(),
6241 vec![
6242 Arc::new(Int32Array::from(vec![1, 2, 3])),
6243 Arc::new(
6244 FixedSizeListArray::try_new_from_values(
6245 Float32Array::from(vec![
6246 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
6247 ]),
6248 4,
6249 )
6250 .unwrap(),
6251 ),
6252 ],
6253 )
6254 .unwrap();
6255
6256 let mut dataset = Dataset::write(
6257 Box::new(RecordBatchIterator::new(
6258 [Ok(dataset_batch)],
6259 schema.clone(),
6260 )),
6261 "memory://test_explain_plan_full_schema_delete_by_source_with_fsl_and_scalar_index",
6262 None,
6263 )
6264 .await
6265 .unwrap();
6266
6267 let scalar_params = ScalarIndexParams::default();
6268 dataset
6269 .create_index(&["id"], IndexType::Scalar, None, &scalar_params, false)
6270 .await
6271 .unwrap();
6272
6273 let merge_insert_job =
6274 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
6275 .unwrap()
6276 .when_matched(WhenMatched::UpdateAll)
6277 .when_not_matched(WhenNotMatched::InsertAll)
6278 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
6279 .try_build()
6280 .unwrap();
6281
6282 let plan = merge_insert_job.explain_plan(None, false).await.unwrap();
6283 assert!(plan.contains("HashJoinExec"));
6284 assert!(plan.contains("join_type=Full"));
6285 assert!(plan.contains("projection=[_rowid"));
6286 assert!(
6287 plan.contains("LanceRead: uri=") && plan.contains("projection=[id]"),
6288 "target-side scan should prune the FSL payload from the join build side even when a scalar index exists: {plan}"
6289 );
6290 assert!(
6291 !plan.contains(
6292 "LanceRead: uri=test_explain_plan_full_schema_delete_by_source_with_fsl_and_scalar_index/data, projection=[id, vec]"
6293 ),
6294 "target-side scan should not include the FSL payload in the join build side: {plan}"
6295 );
6296 }
6297
6298 #[tokio::test]
6299 async fn test_merge_insert_full_schema_delete_by_source_with_fsl() {
6300 let schema = Arc::new(Schema::new(vec![
6301 Field::new("id", DataType::Int32, false),
6302 Field::new(
6303 "vec",
6304 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
6305 true,
6306 ),
6307 ]));
6308
6309 let dataset_batch = RecordBatch::try_new(
6310 schema.clone(),
6311 vec![
6312 Arc::new(Int32Array::from(vec![1, 2, 3])),
6313 Arc::new(
6314 FixedSizeListArray::try_new_from_values(
6315 Float32Array::from(vec![
6316 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
6317 ]),
6318 4,
6319 )
6320 .unwrap(),
6321 ),
6322 ],
6323 )
6324 .unwrap();
6325
6326 let dataset = Dataset::write(
6327 Box::new(RecordBatchIterator::new(
6328 [Ok(dataset_batch)],
6329 schema.clone(),
6330 )),
6331 "memory://test_merge_insert_full_schema_delete_by_source_with_fsl",
6332 None,
6333 )
6334 .await
6335 .unwrap();
6336
6337 let source_batch = RecordBatch::try_new(
6338 schema.clone(),
6339 vec![
6340 Arc::new(Int32Array::from(vec![2, 4])),
6341 Arc::new(
6342 FixedSizeListArray::try_new_from_values(
6343 Float32Array::from(vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]),
6344 4,
6345 )
6346 .unwrap(),
6347 ),
6348 ],
6349 )
6350 .unwrap();
6351
6352 let (merged_dataset, stats) =
6353 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
6354 .unwrap()
6355 .when_matched(WhenMatched::UpdateAll)
6356 .when_not_matched(WhenNotMatched::InsertAll)
6357 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
6358 .try_build()
6359 .unwrap()
6360 .execute_reader(Box::new(RecordBatchIterator::new(
6361 [Ok(source_batch)],
6362 schema.clone(),
6363 )))
6364 .await
6365 .unwrap();
6366
6367 assert_eq!(stats.num_deleted_rows, 2);
6368 assert_eq!(stats.num_updated_rows, 1);
6369 assert_eq!(stats.num_inserted_rows, 1);
6370
6371 let merged = merged_dataset.scan().try_into_batch().await.unwrap();
6372 let ids = merged["id"].as_primitive::<Int32Type>().values().to_vec();
6373 assert_eq!(ids, vec![2, 4]);
6374
6375 let vecs = merged["vec"].as_fixed_size_list();
6376 let actual = vecs
6377 .values()
6378 .as_primitive::<Float32Type>()
6379 .values()
6380 .to_vec();
6381 assert_eq!(actual, vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]);
6382 }
6383
6384 #[tokio::test]
6385 async fn test_merge_insert_full_schema_delete_by_source_with_fsl_and_scalar_index() {
6386 let schema = Arc::new(Schema::new(vec![
6387 Field::new("id", DataType::Int32, false),
6388 Field::new(
6389 "vec",
6390 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
6391 true,
6392 ),
6393 ]));
6394
6395 let dataset_batch = RecordBatch::try_new(
6396 schema.clone(),
6397 vec![
6398 Arc::new(Int32Array::from(vec![1, 2, 3])),
6399 Arc::new(
6400 FixedSizeListArray::try_new_from_values(
6401 Float32Array::from(vec![
6402 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3,
6403 ]),
6404 4,
6405 )
6406 .unwrap(),
6407 ),
6408 ],
6409 )
6410 .unwrap();
6411
6412 let mut dataset = Dataset::write(
6413 Box::new(RecordBatchIterator::new(
6414 [Ok(dataset_batch)],
6415 schema.clone(),
6416 )),
6417 "memory://test_merge_insert_full_schema_delete_by_source_with_fsl_and_scalar_index",
6418 None,
6419 )
6420 .await
6421 .unwrap();
6422
6423 let scalar_params = ScalarIndexParams::default();
6424 dataset
6425 .create_index(&["id"], IndexType::Scalar, None, &scalar_params, false)
6426 .await
6427 .unwrap();
6428
6429 let source_batch = RecordBatch::try_new(
6430 schema.clone(),
6431 vec![
6432 Arc::new(Int32Array::from(vec![2, 4])),
6433 Arc::new(
6434 FixedSizeListArray::try_new_from_values(
6435 Float32Array::from(vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]),
6436 4,
6437 )
6438 .unwrap(),
6439 ),
6440 ],
6441 )
6442 .unwrap();
6443
6444 let (merged_dataset, stats) =
6445 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
6446 .unwrap()
6447 .when_matched(WhenMatched::UpdateAll)
6448 .when_not_matched(WhenNotMatched::InsertAll)
6449 .when_not_matched_by_source(WhenNotMatchedBySource::Delete)
6450 .try_build()
6451 .unwrap()
6452 .execute_reader(Box::new(RecordBatchIterator::new(
6453 [Ok(source_batch)],
6454 schema.clone(),
6455 )))
6456 .await
6457 .unwrap();
6458
6459 assert_eq!(stats.num_deleted_rows, 2);
6460 assert_eq!(stats.num_updated_rows, 1);
6461 assert_eq!(stats.num_inserted_rows, 1);
6462
6463 let merged = merged_dataset.scan().try_into_batch().await.unwrap();
6464 let ids = merged["id"].as_primitive::<Int32Type>().values().to_vec();
6465 assert_eq!(ids, vec![2, 4]);
6466
6467 let vecs = merged["vec"].as_fixed_size_list();
6468 let actual = vecs
6469 .values()
6470 .as_primitive::<Float32Type>()
6471 .values()
6472 .to_vec();
6473 assert_eq!(actual, vec![20.0, 20.1, 20.2, 20.3, 40.0, 40.1, 40.2, 40.3]);
6474 }
6475
6476 #[tokio::test]
6477 async fn test_analyze_plan() {
6478 let mut dataset = lance_datagen::gen_batch()
6480 .col("id", lance_datagen::array::step::<Int32Type>())
6481 .col("name", array::cycle_utf8_literals(&["a", "b", "c"]))
6482 .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(3))
6483 .await
6484 .unwrap();
6485
6486 let original_version = dataset.version().version;
6488
6489 let merge_insert_job =
6491 MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
6492 .unwrap()
6493 .when_matched(WhenMatched::UpdateAll)
6494 .when_not_matched(WhenNotMatched::InsertAll)
6495 .try_build()
6496 .unwrap();
6497
6498 let schema = Arc::new(arrow_schema::Schema::from(dataset.schema()));
6500 let source_batch = RecordBatch::try_new(
6501 schema.clone(),
6502 vec![
6503 Arc::new(Int32Array::from(vec![1, 4])), Arc::new(StringArray::from(vec!["updated_a", "d"])),
6505 ],
6506 )
6507 .unwrap();
6508
6509 let source_stream = RecordBatchStreamAdapter::new(
6510 schema,
6511 futures::stream::once(async { Ok(source_batch) }).boxed(),
6512 );
6513
6514 let mut analysis = String::from("[");
6518 analysis.push_str(
6519 &merge_insert_job
6520 .analyze_plan(Box::pin(source_stream))
6521 .await
6522 .unwrap(),
6523 );
6524 analysis.push_str(&String::from("]"));
6525
6526 assert!(analysis.contains("MergeInsert"));
6528 assert!(analysis.contains("metrics"));
6529 assert!(analysis.contains("bytes_written"));
6533 assert!(analysis.contains("num_files_written"));
6534
6535 dataset.checkout_latest().await.unwrap();
6538 assert_eq!(
6539 dataset.version().version,
6540 original_version,
6541 "analyze_plan should not create a new dataset version"
6542 );
6543
6544 let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_skipped_duplicates=0, num_updated_rows=1]
6546 ...
6547 StreamingTableExec: partition_sizes=1, projection=[id, name], metrics=[]...]";
6548 assert_string_matches(&analysis, expected_pattern).unwrap();
6549 assert!(analysis.contains("bytes_written"));
6550 assert!(analysis.contains("num_files_written"));
6551 assert!(analysis.contains("elapsed_compute"));
6552 }
6553
6554 #[tokio::test]
6555 async fn test_merge_insert_with_action_column() {
6556 let initial_data = RecordBatch::try_new(
6561 Arc::new(arrow_schema::Schema::new(vec![
6562 arrow_schema::Field::new("id", arrow_schema::DataType::Int32, false),
6563 arrow_schema::Field::new("action", arrow_schema::DataType::Utf8, true),
6564 arrow_schema::Field::new("value", arrow_schema::DataType::Int32, true),
6565 ])),
6566 vec![
6567 Arc::new(Int32Array::from(vec![1, 2, 3])),
6568 Arc::new(StringArray::from(vec!["create", "update", "delete"])),
6569 Arc::new(Int32Array::from(vec![10, 20, 30])),
6570 ],
6571 )
6572 .unwrap();
6573
6574 let tempdir = TempStrDir::default();
6575 let dataset = Dataset::write(
6576 RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
6577 &tempdir,
6578 None,
6579 )
6580 .await
6581 .unwrap();
6582
6583 let new_data = RecordBatch::try_new(
6585 Arc::new(arrow_schema::Schema::new(vec![
6586 arrow_schema::Field::new("id", arrow_schema::DataType::Int32, false),
6587 arrow_schema::Field::new("action", arrow_schema::DataType::Utf8, true),
6588 arrow_schema::Field::new("value", arrow_schema::DataType::Int32, true),
6589 ])),
6590 vec![
6591 Arc::new(Int32Array::from(vec![2, 4])),
6592 Arc::new(StringArray::from(vec!["modify", "insert"])),
6593 Arc::new(Int32Array::from(vec![25, 40])),
6594 ],
6595 )
6596 .unwrap();
6597
6598 let merge_insert_job =
6600 MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
6601 .unwrap()
6602 .when_matched(WhenMatched::UpdateAll)
6603 .when_not_matched(WhenNotMatched::InsertAll)
6604 .try_build()
6605 .unwrap();
6606
6607 let new_reader = Box::new(RecordBatchIterator::new(
6608 [Ok(new_data.clone())],
6609 new_data.schema(),
6610 ));
6611 let new_stream = reader_to_stream(new_reader);
6612
6613 let (merged_dataset, _) = merge_insert_job.execute(new_stream).await.unwrap();
6614
6615 let result_batches = merged_dataset
6617 .scan()
6618 .try_into_stream()
6619 .await
6620 .unwrap()
6621 .try_collect::<Vec<_>>()
6622 .await
6623 .unwrap();
6624
6625 let result_batch = concat_batches(&result_batches[0].schema(), &result_batches).unwrap();
6626
6627 assert_eq!(result_batch.num_rows(), 4);
6629
6630 let id_col = result_batch
6632 .column(0)
6633 .as_any()
6634 .downcast_ref::<Int32Array>()
6635 .unwrap();
6636 let action_col = result_batch
6637 .column(1)
6638 .as_any()
6639 .downcast_ref::<StringArray>()
6640 .unwrap();
6641 let value_col = result_batch
6642 .column(2)
6643 .as_any()
6644 .downcast_ref::<Int32Array>()
6645 .unwrap();
6646
6647 for i in 0..result_batch.num_rows() {
6649 match id_col.value(i) {
6650 1 => {
6651 assert_eq!(action_col.value(i), "create");
6652 assert_eq!(value_col.value(i), 10);
6653 }
6654 2 => {
6655 assert_eq!(action_col.value(i), "modify"); assert_eq!(value_col.value(i), 25); }
6658 3 => {
6659 assert_eq!(action_col.value(i), "delete");
6660 assert_eq!(value_col.value(i), 30);
6661 }
6662 4 => {
6663 assert_eq!(action_col.value(i), "insert"); assert_eq!(value_col.value(i), 40); }
6666 _ => panic!("Unexpected id: {}", id_col.value(i)),
6667 }
6668 }
6669 }
6670
6671 #[tokio::test]
6672 #[rstest::rstest]
6673 async fn test_duplicate_rowid_detection(
6674 #[values(false, true)] is_full_schema: bool,
6675 #[values(true, false)] enable_stable_row_ids: bool,
6676 #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1, LanceFileVersion::V2_2)]
6677 data_storage_version: LanceFileVersion,
6678 ) {
6679 let test_uri = "memory://test_duplicate_rowid_multi_fragment.lance";
6680
6681 let dataset = lance_datagen::gen_batch()
6683 .col("key", array::step_custom::<UInt32Type>(1, 1))
6684 .col("value", array::step_custom::<UInt32Type>(10, 10))
6685 .into_dataset_with_params(
6686 test_uri,
6687 FragmentCount(3),
6688 FragmentRowCount(4),
6689 Some(WriteParams {
6690 max_rows_per_file: 4,
6691 enable_stable_row_ids,
6692 data_storage_version: Some(data_storage_version),
6693 ..Default::default()
6694 }),
6695 )
6696 .await
6697 .unwrap();
6698
6699 assert_eq!(dataset.get_fragments().len(), 3, "Should have 3 fragments");
6700
6701 let schema = Arc::new(Schema::new(vec![
6702 Field::new("key", DataType::UInt32, is_full_schema),
6703 Field::new("value", DataType::UInt32, is_full_schema),
6704 ]));
6705
6706 let source_batch = RecordBatch::try_new(
6707 schema.clone(),
6708 vec![
6709 Arc::new(UInt32Array::from(vec![2, 2, 6, 6, 10, 10, 15])),
6710 Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600, 700])),
6711 ],
6712 )
6713 .unwrap();
6714
6715 let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
6716 .unwrap()
6717 .when_matched(WhenMatched::UpdateAll)
6718 .try_build()
6719 .unwrap();
6720
6721 let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
6722 let stream = reader_to_stream(reader);
6723
6724 let result = job.execute(stream).await;
6725
6726 assert!(
6727 result.is_err(),
6728 "Expected merge insert to fail due to duplicate rows on key column."
6729 );
6730
6731 assert!(
6732 matches!(&result, &Err(Error::InvalidInput { ref source, .. }) if source.to_string().contains("Ambiguous merge insert") && source.to_string().contains("multiple source rows")),
6733 "Expected error to be InvalidInput with message about ambiguous merge insert and multiple source rows, got: {:?}",
6734 result
6735 );
6736 }
6737
6738 #[tokio::test]
6739 #[rstest::rstest]
6740 async fn test_source_dedupe_behavior_first_seen(
6741 #[values(false, true)] is_full_schema: bool,
6742 #[values(true, false)] enable_stable_row_ids: bool,
6743 #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1, LanceFileVersion::V2_2)]
6744 data_storage_version: LanceFileVersion,
6745 ) {
6746 let test_uri = format!(
6747 "memory://test_dedupe_first_seen_{}_{}.lance",
6748 is_full_schema, enable_stable_row_ids
6749 );
6750
6751 let dataset = lance_datagen::gen_batch()
6753 .col("key", array::step_custom::<UInt32Type>(1, 1))
6754 .col("value", array::step_custom::<UInt32Type>(10, 10))
6755 .into_dataset_with_params(
6756 &test_uri,
6757 FragmentCount(1),
6758 FragmentRowCount(4),
6759 Some(WriteParams {
6760 max_rows_per_file: 4,
6761 enable_stable_row_ids,
6762 data_storage_version: Some(data_storage_version),
6763 ..Default::default()
6764 }),
6765 )
6766 .await
6767 .unwrap();
6768
6769 let initial_data: Vec<(u32, u32)> = dataset
6771 .scan()
6772 .try_into_batch()
6773 .await
6774 .unwrap()
6775 .columns()
6776 .iter()
6777 .map(|c| c.as_primitive::<UInt32Type>().values().to_vec())
6778 .collect::<Vec<_>>()
6779 .into_iter()
6780 .fold(Vec::new(), |mut acc, vals| {
6781 if acc.is_empty() {
6782 acc = vals.into_iter().map(|v| (v, 0)).collect();
6783 } else {
6784 for (i, v) in vals.into_iter().enumerate() {
6785 acc[i].1 = v;
6786 }
6787 }
6788 acc
6789 });
6790 assert_eq!(
6791 initial_data,
6792 vec![(1, 10), (2, 20), (3, 30), (4, 40)],
6793 "Initial data should be correct"
6794 );
6795
6796 let schema = Arc::new(Schema::new(vec![
6797 Field::new("key", DataType::UInt32, is_full_schema),
6798 Field::new("value", DataType::UInt32, is_full_schema),
6799 ]));
6800
6801 let source_batch = RecordBatch::try_new(
6807 schema.clone(),
6808 vec![
6809 Arc::new(UInt32Array::from(vec![2, 2, 2, 3, 3, 5])),
6810 Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600])),
6811 ],
6812 )
6813 .unwrap();
6814
6815 let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
6816 .unwrap()
6817 .when_matched(WhenMatched::UpdateAll)
6818 .when_not_matched(WhenNotMatched::InsertAll)
6819 .source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
6820 .try_build()
6821 .unwrap();
6822
6823 let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
6824 let stream = reader_to_stream(reader);
6825
6826 let (dataset, stats) = job.execute(stream).await.unwrap();
6827
6828 assert_eq!(
6830 stats.num_skipped_duplicates, 3,
6831 "Should have skipped 3 duplicate rows (2 extra for key=2, 1 extra for key=3)"
6832 );
6833 assert_eq!(
6834 stats.num_updated_rows, 2,
6835 "Should have updated 2 rows (key=2 and key=3)"
6836 );
6837 assert_eq!(
6838 stats.num_inserted_rows, 1,
6839 "Should have inserted 1 row (key=5)"
6840 );
6841
6842 let result_batch = dataset.scan().try_into_batch().await.unwrap();
6844 let keys = result_batch.column(0).as_primitive::<UInt32Type>();
6845 let values = result_batch.column(1).as_primitive::<UInt32Type>();
6846
6847 let result_data: std::collections::HashMap<u32, u32> = keys
6848 .values()
6849 .iter()
6850 .zip(values.values().iter())
6851 .map(|(&k, &v)| (k, v))
6852 .collect();
6853
6854 assert_eq!(result_data.len(), 5, "Should have 5 rows total");
6855 assert_eq!(
6856 result_data.get(&1),
6857 Some(&10),
6858 "key=1 should be unchanged (original value)"
6859 );
6860 assert_eq!(
6861 result_data.get(&2),
6862 Some(&100),
6863 "key=2 should have first seen value (100, not 200 or 300)"
6864 );
6865 assert_eq!(
6866 result_data.get(&3),
6867 Some(&400),
6868 "key=3 should have first seen value (400, not 500)"
6869 );
6870 assert_eq!(
6871 result_data.get(&4),
6872 Some(&40),
6873 "key=4 should be unchanged (original value)"
6874 );
6875 assert_eq!(
6876 result_data.get(&5),
6877 Some(&600),
6878 "key=5 should be inserted with value 600"
6879 );
6880 }
6881
6882 #[tokio::test]
6883 async fn test_merge_insert_use_index() {
6884 let data = lance_datagen::gen_batch()
6885 .col("id", lance_datagen::array::step::<Int32Type>())
6886 .col("value", array::step::<UInt32Type>());
6887 let data = data.into_reader_rows(RowCount::from(100), BatchCount::from(1));
6888 let schema = data.schema();
6889 let mut ds = Dataset::write(data, "memory://", None).await.unwrap();
6890
6891 let index_params = ScalarIndexParams::default();
6893 ds.create_index(&["id"], IndexType::Scalar, None, &index_params, false)
6894 .await
6895 .unwrap();
6896
6897 let source_batch = RecordBatch::try_new(
6898 schema.clone(),
6899 vec![
6900 Arc::new(Int32Array::from(vec![1, 2, 101])), Arc::new(UInt32Array::from(vec![999, 999, 999])),
6902 ],
6903 )
6904 .unwrap();
6905
6906 let merge_job_no_index =
6908 MergeInsertBuilder::try_new(Arc::new(ds.clone()), vec!["id".to_string()])
6909 .unwrap()
6910 .when_matched(WhenMatched::UpdateAll)
6911 .when_not_matched(WhenNotMatched::InsertAll)
6912 .use_index(false) .try_build()
6914 .unwrap();
6915
6916 let plan = merge_job_no_index.explain_plan(None, false).await;
6918 assert!(
6919 plan.is_ok(),
6920 "explain_plan should succeed with use_index=false"
6921 );
6922 let plan_str = plan.unwrap();
6923 assert!(plan_str.contains("MergeInsert"));
6924 assert!(plan_str.contains("HashJoinExec")); let merge_job_with_index =
6928 MergeInsertBuilder::try_new(Arc::new(ds.clone()), vec!["id".to_string()])
6929 .unwrap()
6930 .when_matched(WhenMatched::UpdateAll)
6931 .when_not_matched(WhenNotMatched::InsertAll)
6932 .use_index(true) .try_build()
6934 .unwrap();
6935
6936 let plan_result = merge_job_with_index.explain_plan(None, false).await;
6938 assert!(
6939 plan_result.is_err(),
6940 "explain_plan should fail with use_index=true when index exists"
6941 );
6942
6943 match plan_result {
6944 Err(Error::NotSupported { source, .. }) => {
6945 assert!(source.to_string().contains("does not support explain_plan"));
6946 }
6947 _ => panic!("Expected NotSupported error"),
6948 }
6949
6950 let source = Box::new(RecordBatchIterator::new(
6952 vec![Ok(source_batch.clone())],
6953 schema.clone(),
6954 ));
6955 let (result_ds, stats) = merge_job_no_index.execute_reader(source).await.unwrap();
6956 assert_eq!(stats.num_updated_rows, 2);
6957 assert_eq!(stats.num_inserted_rows, 1);
6958
6959 let updated_count = result_ds
6961 .count_rows(Some("value = 999".to_string()))
6962 .await
6963 .unwrap();
6964 assert_eq!(updated_count, 3);
6965 }
6966
6967 #[tokio::test]
6968 async fn test_full_schema_upsert_fragment_bitmap() {
6969 let schema = Arc::new(Schema::new(vec![
6970 Field::new("key", DataType::UInt32, true),
6971 Field::new("value", DataType::UInt32, true),
6972 Field::new(
6973 "vec",
6974 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
6975 true,
6976 ),
6977 ]));
6978
6979 let mut dataset = lance_datagen::gen_batch()
6980 .col("key", array::step_custom::<UInt32Type>(1, 1))
6981 .col("value", array::step_custom::<UInt32Type>(10, 10))
6982 .col(
6983 "vec",
6984 array::cycle_vec(
6985 array::cycle::<Float32Type>(vec![
6986 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
6987 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
6988 ]),
6989 Dimension::from(4),
6990 ),
6991 )
6992 .into_ram_dataset_with_params(
6993 FragmentCount::from(2),
6994 FragmentRowCount::from(3),
6995 Some(WriteParams {
6996 max_rows_per_file: 3,
6997 enable_stable_row_ids: true,
6998 ..Default::default()
6999 }),
7000 )
7001 .await
7002 .unwrap();
7003
7004 let scalar_params = ScalarIndexParams::default();
7005 dataset
7006 .create_index(
7007 &["value"],
7008 IndexType::Scalar,
7009 Some("value_idx".to_string()),
7010 &scalar_params,
7011 true,
7012 )
7013 .await
7014 .unwrap();
7015
7016 let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2);
7017 dataset
7018 .create_index(
7019 &["vec"],
7020 IndexType::Vector,
7021 Some("vec_idx".to_string()),
7022 &vector_params,
7023 true,
7024 )
7025 .await
7026 .unwrap();
7027
7028 let indices = dataset.load_indices().await.unwrap();
7029 let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap();
7030 let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap();
7031
7032 assert_eq!(
7033 value_index
7034 .fragment_bitmap
7035 .as_ref()
7036 .unwrap()
7037 .iter()
7038 .collect::<Vec<_>>(),
7039 vec![0, 1]
7040 );
7041 assert_eq!(
7042 vec_index
7043 .fragment_bitmap
7044 .as_ref()
7045 .unwrap()
7046 .iter()
7047 .collect::<Vec<_>>(),
7048 vec![0, 1]
7049 );
7050
7051 let upsert_keys = UInt32Array::from(vec![2, 5]);
7053 let upsert_values = UInt32Array::from(vec![200, 500]);
7054 let upsert_vecs = FixedSizeListArray::try_new_from_values(
7055 Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]),
7056 4,
7057 )
7058 .unwrap();
7059
7060 let upsert_batch = RecordBatch::try_new(
7061 schema.clone(),
7062 vec![
7063 Arc::new(upsert_keys),
7064 Arc::new(upsert_values),
7065 Arc::new(upsert_vecs),
7066 ],
7067 )
7068 .unwrap();
7069
7070 let upsert_stream = RecordBatchStreamAdapter::new(
7071 schema.clone(),
7072 futures::stream::once(async { Ok(upsert_batch) }).boxed(),
7073 );
7074
7075 let (updated_dataset, _stats) =
7076 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
7077 .unwrap()
7078 .when_matched(WhenMatched::UpdateAll)
7079 .when_not_matched(WhenNotMatched::DoNothing)
7080 .when_not_matched_by_source(WhenNotMatchedBySource::Keep)
7081 .try_build()
7082 .unwrap()
7083 .execute(Box::pin(upsert_stream))
7084 .await
7085 .unwrap();
7086
7087 let fragments = updated_dataset.get_fragments();
7088 assert_eq!(fragments.len(), 3);
7089 }
7090
7091 #[tokio::test]
7092 async fn test_sub_schema_upsert_fragment_bitmap() {
7093 let mut dataset = lance_datagen::gen_batch()
7094 .col("key", array::step_custom::<UInt32Type>(1, 1))
7095 .col("value", array::step_custom::<UInt32Type>(10, 10))
7096 .col(
7097 "vec",
7098 array::cycle_vec(
7099 array::cycle::<Float32Type>(vec![
7100 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
7101 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
7102 ]),
7103 Dimension::from(4),
7104 ),
7105 )
7106 .into_ram_dataset_with_params(
7107 FragmentCount::from(2),
7108 FragmentRowCount::from(3),
7109 Some(WriteParams {
7110 max_rows_per_file: 3,
7111 enable_stable_row_ids: true,
7112 ..Default::default()
7113 }),
7114 )
7115 .await
7116 .unwrap();
7117
7118 let scalar_params = ScalarIndexParams::default();
7119 dataset
7120 .create_index(
7121 &["value"],
7122 IndexType::Scalar,
7123 Some("value_idx".to_string()),
7124 &scalar_params,
7125 true,
7126 )
7127 .await
7128 .unwrap();
7129
7130 let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2);
7131 dataset
7132 .create_index(
7133 &["vec"],
7134 IndexType::Vector,
7135 Some("vec_idx".to_string()),
7136 &vector_params,
7137 true,
7138 )
7139 .await
7140 .unwrap();
7141
7142 let indices = dataset.load_indices().await.unwrap();
7143 let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap();
7144 let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap();
7145
7146 assert_eq!(
7147 value_index
7148 .fragment_bitmap
7149 .as_ref()
7150 .unwrap()
7151 .iter()
7152 .collect::<Vec<_>>(),
7153 vec![0, 1]
7154 );
7155 assert_eq!(
7156 vec_index
7157 .fragment_bitmap
7158 .as_ref()
7159 .unwrap()
7160 .iter()
7161 .collect::<Vec<_>>(),
7162 vec![0, 1]
7163 );
7164
7165 let sub_schema = Arc::new(Schema::new(vec![
7166 Field::new("key", DataType::UInt32, true),
7167 Field::new(
7168 "vec",
7169 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
7170 true,
7171 ),
7172 ]));
7173
7174 let upsert_keys = UInt32Array::from(vec![2, 5]);
7175 let upsert_vecs = FixedSizeListArray::try_new_from_values(
7176 Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]),
7177 4,
7178 )
7179 .unwrap();
7180
7181 let upsert_batch = RecordBatch::try_new(
7182 sub_schema.clone(),
7183 vec![Arc::new(upsert_keys), Arc::new(upsert_vecs)],
7184 )
7185 .unwrap();
7186
7187 let upsert_stream = RecordBatchStreamAdapter::new(
7188 sub_schema.clone(),
7189 futures::stream::once(async { Ok(upsert_batch) }).boxed(),
7190 );
7191
7192 let (updated_dataset, _stats) =
7193 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
7194 .unwrap()
7195 .when_matched(WhenMatched::UpdateAll)
7196 .when_not_matched(WhenNotMatched::DoNothing)
7197 .when_not_matched_by_source(WhenNotMatchedBySource::Keep)
7198 .try_build()
7199 .unwrap()
7200 .execute(Box::pin(upsert_stream))
7201 .await
7202 .unwrap();
7203
7204 let fragments = updated_dataset.get_fragments();
7205 assert_eq!(fragments.len(), 3);
7212
7213 let updated_indices = updated_dataset.load_indices().await.unwrap();
7214 assert_eq!(updated_indices.len(), 2);
7221 let updated_value_index = updated_indices
7222 .iter()
7223 .find(|idx| idx.name == "value_idx")
7224 .unwrap();
7225
7226 let value_bitmap = updated_value_index.fragment_bitmap.as_ref().unwrap();
7231 assert!(value_bitmap.contains(0));
7232 assert!(value_bitmap.contains(1));
7233 }
7234
7235 #[tokio::test]
7236 async fn test_when_matched_fail() {
7237 let dataset = create_test_dataset("memory://test_fail", LanceFileVersion::V2_0, true).await;
7238
7239 let new_data = RecordBatch::try_new(
7241 create_test_schema(),
7242 vec![
7243 Arc::new(UInt32Array::from(vec![1, 2, 10, 11])), Arc::new(UInt32Array::from(vec![100, 200, 1000, 1100])),
7245 Arc::new(StringArray::from(vec!["X", "Y", "Z", "W"])),
7246 ],
7247 )
7248 .unwrap();
7249
7250 let reader = Box::new(RecordBatchIterator::new(
7251 [Ok(new_data.clone())],
7252 new_data.schema(),
7253 ));
7254 let new_stream = reader_to_stream(reader);
7255
7256 let result = MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
7257 .unwrap()
7258 .when_matched(WhenMatched::Fail)
7259 .when_not_matched(WhenNotMatched::InsertAll)
7260 .try_build()
7261 .unwrap()
7262 .execute(new_stream)
7263 .await;
7264
7265 match result {
7267 Ok((_dataset, stats)) => {
7268 panic!(
7269 "Expected merge insert to fail, but it succeeded. Stats: {:?}",
7270 stats
7271 );
7272 }
7273 Err(e) => {
7274 let error_msg = e.to_string();
7275 assert!(error_msg.contains("Merge insert failed"));
7276 assert!(error_msg.contains("found matching row"));
7277 }
7278 }
7279
7280 let new_data = RecordBatch::try_new(
7282 create_test_schema(),
7283 vec![
7284 Arc::new(UInt32Array::from(vec![10, 11, 12])), Arc::new(UInt32Array::from(vec![1000, 1100, 1200])),
7286 Arc::new(StringArray::from(vec!["X", "Y", "Z"])),
7287 ],
7288 )
7289 .unwrap();
7290
7291 let reader = Box::new(RecordBatchIterator::new(
7292 [Ok(new_data.clone())],
7293 new_data.schema(),
7294 ));
7295 let new_stream = reader_to_stream(reader);
7296
7297 let (updated_dataset, stats) =
7298 MergeInsertBuilder::try_new(dataset.clone(), vec!["key".to_string()])
7299 .unwrap()
7300 .when_matched(WhenMatched::Fail)
7301 .when_not_matched(WhenNotMatched::InsertAll)
7302 .try_build()
7303 .unwrap()
7304 .execute(new_stream)
7305 .await
7306 .unwrap();
7307
7308 assert_eq!(stats.num_inserted_rows, 3);
7310 assert_eq!(stats.num_updated_rows, 0);
7311 assert_eq!(stats.num_deleted_rows, 0);
7312
7313 let count = updated_dataset
7315 .count_rows(Some("key >= 10".to_string()))
7316 .await
7317 .unwrap();
7318 assert_eq!(count, 3);
7319 }
7320
7321 #[tokio::test]
7329 async fn test_merge_insert_permissive_nullability() {
7330 let non_nullable_schema = Arc::new(Schema::new(vec![
7332 Field::new("id", DataType::Int64, false), Field::new("value", DataType::Int64, false), ]));
7335
7336 let initial_data = RecordBatch::try_new(
7337 non_nullable_schema.clone(),
7338 vec![
7339 Arc::new(Int64Array::from(vec![1, 2, 3])),
7340 Arc::new(Int64Array::from(vec![100, 200, 300])),
7341 ],
7342 )
7343 .unwrap();
7344
7345 let test_uri = "memory://test_nullable_issue_4654";
7346 let dataset = Dataset::write(
7347 RecordBatchIterator::new(vec![Ok(initial_data)], non_nullable_schema.clone()),
7348 test_uri,
7349 None,
7350 )
7351 .await
7352 .unwrap();
7353
7354 let nullable_schema = Arc::new(Schema::new(vec![
7356 Field::new("id", DataType::Int64, true), Field::new("value", DataType::Int64, true), ]));
7359
7360 let new_data = RecordBatch::try_new(
7361 nullable_schema.clone(),
7362 vec![
7363 Arc::new(Int64Array::from(vec![2, 4, 5])), Arc::new(Int64Array::from(vec![999, 400, 500])), ],
7366 )
7367 .unwrap();
7368
7369 let merge_result = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
7371 .unwrap()
7372 .when_matched(WhenMatched::UpdateAll)
7373 .when_not_matched(WhenNotMatched::InsertAll)
7374 .try_build()
7375 .unwrap()
7376 .execute_reader(Box::new(RecordBatchIterator::new(
7377 vec![Ok(new_data.clone())],
7378 nullable_schema.clone(),
7379 )))
7380 .await;
7381
7382 assert!(
7383 merge_result.is_ok(),
7384 "merge_insert() should succeed with nullable fields but no actual nulls. \
7385 This is the same behavior as insert/append. Error: {:?}",
7386 merge_result.err()
7387 );
7388
7389 let (merged_dataset, stats) = merge_result.unwrap();
7391
7392 assert_eq!(stats.num_updated_rows, 1, "Should update 1 row (id=2)");
7394 assert_eq!(
7395 stats.num_inserted_rows, 2,
7396 "Should insert 2 new rows (id=4,5)"
7397 );
7398
7399 let count = merged_dataset.count_rows(None).await.unwrap();
7401 assert_eq!(count, 5, "Should have 5 total rows");
7402
7403 let result = merged_dataset
7405 .scan()
7406 .filter("id = 2")
7407 .unwrap()
7408 .try_into_stream()
7409 .await
7410 .unwrap()
7411 .try_collect::<Vec<_>>()
7412 .await
7413 .unwrap();
7414
7415 let batch = concat_batches(&result[0].schema(), &result).unwrap();
7416 assert_eq!(batch.num_rows(), 1);
7417 let value_array = batch
7418 .column(1)
7419 .as_any()
7420 .downcast_ref::<Int64Array>()
7421 .unwrap();
7422 assert_eq!(
7423 value_array.value(0),
7424 999,
7425 "Value for id=2 should be updated to 999"
7426 );
7427 }
7428
7429 #[tokio::test]
7438 async fn test_merge_insert_null_on_column_inserts() {
7439 let initial_data = record_batch!(
7441 ("id", Int32, [0]),
7442 ("record_type", Utf8, [Option::<&str>::None]),
7443 ("value", Int32, [10])
7444 )
7445 .unwrap();
7446
7447 let dataset = Dataset::write(
7448 RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
7449 "memory://test_null_on_column",
7450 None,
7451 )
7452 .await
7453 .unwrap();
7454
7455 let new_data = record_batch!(
7459 ("id", Int32, [Some(2)]),
7460 ("record_type", Utf8, [Option::<&str>::None]),
7461 ("value", Int32, [Some(99)])
7462 )
7463 .unwrap();
7464
7465 let (merged_dataset, stats) = MergeInsertBuilder::try_new(
7466 Arc::new(dataset),
7467 vec!["id".to_string(), "record_type".to_string()],
7468 )
7469 .unwrap()
7470 .when_matched(WhenMatched::UpdateAll)
7471 .when_not_matched(WhenNotMatched::InsertAll)
7472 .try_build()
7473 .unwrap()
7474 .execute_reader(Box::new(RecordBatchIterator::new(
7475 vec![Ok(new_data.clone())],
7476 new_data.schema(),
7477 )))
7478 .await
7479 .unwrap();
7480
7481 assert_eq!(
7483 stats.num_inserted_rows, 1,
7484 "row with NULL ON column should be inserted"
7485 );
7486 assert_eq!(stats.num_updated_rows, 0, "no row should be updated");
7487
7488 let count = merged_dataset.count_rows(None).await.unwrap();
7489 assert_eq!(
7490 count, 2,
7491 "dataset should have the original row plus the newly inserted row"
7492 );
7493 }
7494
7495 #[tokio::test]
7500 async fn test_merge_insert_partial_composite_key_null() {
7501 let initial_data = record_batch!(
7503 ("id", Int32, [Some(1)]),
7504 ("record_type", Utf8, [Some("A")]),
7505 ("value", Int32, [Some(10)])
7506 )
7507 .unwrap();
7508
7509 let dataset = Dataset::write(
7510 RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
7511 "memory://test_partial_composite_null",
7512 None,
7513 )
7514 .await
7515 .unwrap();
7516
7517 let new_data = record_batch!(
7521 ("id", Int32, [Some(1)]),
7522 ("record_type", Utf8, [Option::<&str>::None]),
7523 ("value", Int32, [Some(99)])
7524 )
7525 .unwrap();
7526
7527 let (merged_dataset, stats) = MergeInsertBuilder::try_new(
7528 Arc::new(dataset),
7529 vec!["id".to_string(), "record_type".to_string()],
7530 )
7531 .unwrap()
7532 .when_matched(WhenMatched::UpdateAll)
7533 .when_not_matched(WhenNotMatched::InsertAll)
7534 .try_build()
7535 .unwrap()
7536 .execute_reader(Box::new(RecordBatchIterator::new(
7537 vec![Ok(new_data.clone())],
7538 new_data.schema(),
7539 )))
7540 .await
7541 .unwrap();
7542
7543 assert_eq!(
7545 stats.num_inserted_rows, 1,
7546 "row with partial NULL composite key should be inserted"
7547 );
7548 assert_eq!(
7549 stats.num_updated_rows, 0,
7550 "existing (id=1, record_type=A) row must not be updated"
7551 );
7552
7553 let count = merged_dataset.count_rows(None).await.unwrap();
7555 assert_eq!(
7556 count, 2,
7557 "both the original and the new row must be present"
7558 );
7559 }
7560
7561 #[tokio::test]
7566 async fn test_merge_insert_null_single_on_column() {
7567 let initial_data = record_batch!(
7569 ("id", Int32, [Option::<i32>::None]),
7570 ("value", Int32, [Some(1)])
7571 )
7572 .unwrap();
7573
7574 let dataset = Dataset::write(
7575 RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
7576 "memory://test_null_single_on_column",
7577 None,
7578 )
7579 .await
7580 .unwrap();
7581
7582 let new_data = record_batch!(
7586 ("id", Int32, [Option::<i32>::None, Some(5)]),
7587 ("value", Int32, [Some(99), Some(50)])
7588 )
7589 .unwrap();
7590
7591 let (merged_dataset, stats) =
7592 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
7593 .unwrap()
7594 .when_matched(WhenMatched::UpdateAll)
7595 .when_not_matched(WhenNotMatched::InsertAll)
7596 .try_build()
7597 .unwrap()
7598 .execute_reader(Box::new(RecordBatchIterator::new(
7599 vec![Ok(new_data.clone())],
7600 new_data.schema(),
7601 )))
7602 .await
7603 .unwrap();
7604
7605 assert_eq!(
7607 stats.num_inserted_rows, 2,
7608 "both rows with NULL ON column should be inserted"
7609 );
7610 assert_eq!(stats.num_updated_rows, 0);
7611
7612 let count = merged_dataset.count_rows(None).await.unwrap();
7614 assert_eq!(count, 3);
7615 }
7616
7617 #[tokio::test]
7620 async fn test_merge_insert_subschema_invalid_type_error() {
7621 let schema = Arc::new(Schema::new(vec![
7623 Field::new("id", DataType::Int32, false),
7624 Field::new("value", DataType::Float64, true), Field::new("extra", DataType::Utf8, true),
7626 ]));
7627
7628 let initial_data = RecordBatch::try_new(
7629 schema.clone(),
7630 vec![
7631 Arc::new(Int32Array::from(vec![1, 2, 3])),
7632 Arc::new(Float64Array::from(vec![1.1, 2.2, 3.3])),
7633 Arc::new(StringArray::from(vec!["a", "b", "c"])),
7634 ],
7635 )
7636 .unwrap();
7637
7638 let test_uri = "memory://test_issue_3634";
7639 let dataset = Dataset::write(
7640 RecordBatchIterator::new(vec![Ok(initial_data)], schema),
7641 test_uri,
7642 None,
7643 )
7644 .await
7645 .unwrap();
7646
7647 let subschema_with_wrong_type = Arc::new(Schema::new(vec![
7649 Field::new("id", DataType::Int32, false),
7650 Field::new("value", DataType::Int32, true),
7651 ]));
7652
7653 let new_data = RecordBatch::try_new(
7654 subschema_with_wrong_type.clone(),
7655 vec![
7656 Arc::new(Int32Array::from(vec![2, 4])),
7657 Arc::new(Int32Array::from(vec![22, 44])),
7658 ],
7659 )
7660 .unwrap();
7661
7662 let merge_result = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
7664 .unwrap()
7665 .when_matched(WhenMatched::UpdateAll)
7666 .when_not_matched(WhenNotMatched::InsertAll)
7667 .try_build()
7668 .unwrap()
7669 .execute_reader(Box::new(RecordBatchIterator::new(
7670 vec![Ok(new_data)],
7671 subschema_with_wrong_type,
7672 )))
7673 .await;
7674
7675 let err = merge_result.expect_err("Merge insert should have failed but it succeeded.");
7677 assert!(
7678 matches!(err, lance_core::Error::SchemaMismatch { .. }),
7679 "Expected a SchemaMismatch error, but got a different error type: {:?}",
7680 err
7681 );
7682
7683 let error_message = err.to_string();
7684 assert!(
7685 error_message.contains("`value` should have type double but type was int32"),
7686 "Error message should specify the expected (double) and actual (int32) types for 'value', but was: {}",
7687 error_message
7688 );
7689
7690 assert!(
7691 !error_message.contains("missing="),
7692 "Error message should NOT complain about missing fields for a subschema check, but was: {}",
7693 error_message
7694 );
7695 }
7696
7697 #[tokio::test]
7701 async fn test_merge_insert_mixed_case_key() {
7702 let schema = Arc::new(Schema::new(vec![
7704 Field::new("userId", DataType::UInt32, false),
7705 Field::new("value", DataType::UInt32, true),
7706 ]));
7707
7708 let initial_batch = RecordBatch::try_new(
7710 schema.clone(),
7711 vec![
7712 Arc::new(UInt32Array::from(vec![1, 2, 3])),
7713 Arc::new(UInt32Array::from(vec![10, 20, 30])),
7714 ],
7715 )
7716 .unwrap();
7717
7718 let test_uri = "memory://test_mixed_case.lance";
7720 let ds = Dataset::write(
7721 RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()),
7722 test_uri,
7723 None,
7724 )
7725 .await
7726 .unwrap();
7727
7728 let new_batch = RecordBatch::try_new(
7730 schema.clone(),
7731 vec![
7732 Arc::new(UInt32Array::from(vec![2, 4])),
7733 Arc::new(UInt32Array::from(vec![200, 400])),
7734 ],
7735 )
7736 .unwrap();
7737
7738 let job = MergeInsertBuilder::try_new(Arc::new(ds), vec!["userId".to_string()])
7740 .unwrap()
7741 .when_matched(WhenMatched::UpdateAll)
7742 .try_build()
7743 .unwrap();
7744
7745 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
7746 let new_stream = reader_to_stream(new_reader);
7747
7748 let (merged_ds, _merge_stats) = job.execute(new_stream).await.unwrap();
7749
7750 let result = merged_ds
7752 .scan()
7753 .try_into_stream()
7754 .await
7755 .unwrap()
7756 .try_collect::<Vec<_>>()
7757 .await
7758 .unwrap();
7759
7760 let result_batch = concat_batches(&schema, &result).unwrap();
7761 assert_eq!(result_batch.num_rows(), 4); let user_ids = result_batch
7765 .column(0)
7766 .as_any()
7767 .downcast_ref::<UInt32Array>()
7768 .unwrap();
7769 let values = result_batch
7770 .column(1)
7771 .as_any()
7772 .downcast_ref::<UInt32Array>()
7773 .unwrap();
7774
7775 for i in 0..result_batch.num_rows() {
7777 if user_ids.value(i) == 2 {
7778 assert_eq!(
7779 values.value(i),
7780 200,
7781 "userId=2 should have been updated to value=200"
7782 );
7783 }
7784 }
7785 }
7786
7787 #[tokio::test]
7790 async fn test_merge_insert_reordered_columns() {
7791 use arrow_array::record_batch;
7792
7793 let initial_data = record_batch!(
7794 ("id", Int32, [1, 2, 3]),
7795 ("value", Float64, [1.1, 2.2, 3.3]),
7796 ("extra", Utf8, ["a", "b", "c"])
7797 )
7798 .unwrap();
7799
7800 let dataset = Dataset::write(
7801 RecordBatchIterator::new(vec![Ok(initial_data.clone())], initial_data.schema()),
7802 "memory://test_issue_5323",
7803 None,
7804 )
7805 .await
7806 .unwrap();
7807
7808 let new_data = record_batch!(
7810 ("extra", Utf8, ["x", "y"]),
7811 ("id", Int32, [2, 4]), ("value", Float64, [22.2, 44.4])
7813 )
7814 .unwrap();
7815
7816 let job = MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
7818 .unwrap()
7819 .when_matched(WhenMatched::UpdateAll)
7820 .when_not_matched(WhenNotMatched::InsertAll)
7821 .try_build()
7822 .unwrap();
7823 assert!(
7824 job.can_use_create_plan(&new_data.schema()).await.unwrap(),
7825 "Reordered schema should be able to use fast path"
7826 );
7827
7828 let (merged_dataset, _) =
7830 MergeInsertBuilder::try_new(Arc::new(dataset), vec!["id".to_string()])
7831 .unwrap()
7832 .when_matched(WhenMatched::UpdateAll)
7833 .when_not_matched(WhenNotMatched::InsertAll)
7834 .try_build()
7835 .unwrap()
7836 .execute_reader(Box::new(RecordBatchIterator::new(
7837 vec![Ok(new_data.clone())],
7838 new_data.schema(),
7839 )))
7840 .await
7841 .unwrap();
7842
7843 let result = merged_dataset
7844 .scan()
7845 .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
7846 "id".to_string(),
7847 )]))
7848 .unwrap()
7849 .try_into_batch()
7850 .await
7851 .unwrap();
7852
7853 let expected = record_batch!(
7854 ("id", Int32, [1, 2, 3, 4]),
7855 ("value", Float64, [1.1, 22.2, 3.3, 44.4]),
7856 ("extra", Utf8, ["a", "x", "c", "y"])
7857 )
7858 .unwrap();
7859
7860 assert_eq!(result, expected);
7861 }
7862
7863 #[rstest::rstest]
7867 #[tokio::test]
7868 async fn test_when_matched_delete_full_schema(
7869 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
7870 #[values(true, false)] enable_stable_row_ids: bool,
7871 ) {
7872 let schema = create_test_schema();
7873 let test_uri = "memory://test_delete_full.lance";
7874
7875 let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
7877
7878 let new_batch = RecordBatch::try_new(
7882 schema.clone(),
7883 vec![
7884 Arc::new(UInt32Array::from(vec![4, 5, 6, 7, 8, 9])),
7885 Arc::new(UInt32Array::from(vec![2, 2, 2, 2, 2, 2])),
7886 Arc::new(StringArray::from(vec!["A", "B", "C", "A", "B", "C"])),
7887 ],
7888 )
7889 .unwrap();
7890
7891 let keys = vec!["key".to_string()];
7892
7893 let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
7897 .unwrap()
7898 .when_matched(WhenMatched::Delete)
7899 .when_not_matched(WhenNotMatched::DoNothing)
7900 .try_build()
7901 .unwrap();
7902 let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
7903 [Ok(new_batch.clone())],
7904 schema.clone(),
7905 )));
7906 let plan = plan_job.create_plan(plan_stream).await.unwrap();
7907 assert_plan_node_equals(
7908 plan,
7909 "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
7910 ...
7911 HashJoinExec: ...join_type=Inner...
7912 ...
7913 ...
7914 StreamingTableExec: partition_sizes=1, projection=[key]",
7915 )
7916 .await
7917 .unwrap();
7918 let job = MergeInsertBuilder::try_new(ds.clone(), keys)
7919 .unwrap()
7920 .when_matched(WhenMatched::Delete)
7921 .when_not_matched(WhenNotMatched::DoNothing)
7922 .try_build()
7923 .unwrap();
7924
7925 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
7926 let new_stream = reader_to_stream(new_reader);
7927
7928 let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
7929
7930 assert_eq!(merge_stats.num_deleted_rows, 3);
7932 assert_eq!(merge_stats.num_inserted_rows, 0);
7933 assert_eq!(merge_stats.num_updated_rows, 0);
7934
7935 let batches = merged_dataset
7937 .scan()
7938 .try_into_stream()
7939 .await
7940 .unwrap()
7941 .try_collect::<Vec<_>>()
7942 .await
7943 .unwrap();
7944
7945 let merged = concat_batches(&schema, &batches).unwrap();
7946 let mut remaining_keys: Vec<u32> = merged
7947 .column(0)
7948 .as_primitive::<UInt32Type>()
7949 .values()
7950 .to_vec();
7951 remaining_keys.sort();
7952 assert_eq!(remaining_keys, vec![1, 2, 3]);
7953 }
7954
7955 #[rstest::rstest]
7958 #[tokio::test]
7959 async fn test_when_matched_delete_id_only(
7960 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
7961 #[values(true, false)] enable_stable_row_ids: bool,
7962 ) {
7963 let test_uri = "memory://test_delete_id_only.lance";
7964
7965 let ds = create_test_dataset(test_uri, version, enable_stable_row_ids).await;
7967 let id_only_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::UInt32, true)]));
7968 let new_batch = RecordBatch::try_new(
7969 id_only_schema.clone(),
7970 vec![Arc::new(UInt32Array::from(vec![2, 4, 6]))], )
7972 .unwrap();
7973
7974 let keys = vec!["key".to_string()];
7975
7976 let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
7979 .unwrap()
7980 .when_matched(WhenMatched::Delete)
7981 .when_not_matched(WhenNotMatched::DoNothing)
7982 .try_build()
7983 .unwrap();
7984 let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
7985 [Ok(new_batch.clone())],
7986 id_only_schema.clone(),
7987 )));
7988 let plan = plan_job.create_plan(plan_stream).await.unwrap();
7989 assert_plan_node_equals(
7990 plan,
7991 "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
7992 ...
7993 HashJoinExec: ...join_type=Inner...
7994 ...
7995 ...
7996 StreamingTableExec: partition_sizes=1, projection=[key]",
7997 )
7998 .await
7999 .unwrap();
8000 let job = MergeInsertBuilder::try_new(ds.clone(), keys)
8001 .unwrap()
8002 .when_matched(WhenMatched::Delete)
8003 .when_not_matched(WhenNotMatched::DoNothing)
8004 .try_build()
8005 .unwrap();
8006
8007 let new_reader = Box::new(RecordBatchIterator::new(
8008 [Ok(new_batch)],
8009 id_only_schema.clone(),
8010 ));
8011 let new_stream = reader_to_stream(new_reader);
8012
8013 let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
8014
8015 assert_eq!(merge_stats.num_deleted_rows, 3);
8017 assert_eq!(merge_stats.num_inserted_rows, 0);
8018 assert_eq!(merge_stats.num_updated_rows, 0);
8019
8020 let full_schema = create_test_schema();
8022 let batches = merged_dataset
8023 .scan()
8024 .try_into_stream()
8025 .await
8026 .unwrap()
8027 .try_collect::<Vec<_>>()
8028 .await
8029 .unwrap();
8030
8031 let merged = concat_batches(&full_schema, &batches).unwrap();
8032 let mut remaining_keys: Vec<u32> = merged
8033 .column(0)
8034 .as_primitive::<UInt32Type>()
8035 .values()
8036 .to_vec();
8037 remaining_keys.sort();
8038 assert_eq!(remaining_keys, vec![1, 3, 5]);
8039 }
8040
8041 #[rstest::rstest]
8044 #[tokio::test]
8045 async fn test_when_matched_delete_with_insert(
8046 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
8047 ) {
8048 let schema = create_test_schema();
8049 let test_uri = "memory://test_delete_with_insert.lance";
8050
8051 let ds = create_test_dataset(test_uri, version, false).await;
8053
8054 let new_batch = create_new_batch(schema.clone());
8056
8057 let keys = vec!["key".to_string()];
8058
8059 let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
8061 .unwrap()
8062 .when_matched(WhenMatched::Delete)
8063 .when_not_matched(WhenNotMatched::InsertAll)
8064 .try_build()
8065 .unwrap();
8066 let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
8067 [Ok(new_batch.clone())],
8068 schema.clone(),
8069 )));
8070 let plan = plan_job.create_plan(plan_stream).await.unwrap();
8071 assert_plan_node_equals(
8072 plan,
8073 "MergeInsert: on=[key], when_matched=Delete, when_not_matched=InsertAll, when_not_matched_by_source=Keep...THEN 2 WHEN...THEN 3 ELSE 0 END as __action]...projection=[key, value, filterme]"
8074 ).await.unwrap();
8075
8076 let job = MergeInsertBuilder::try_new(ds.clone(), keys)
8078 .unwrap()
8079 .when_matched(WhenMatched::Delete)
8080 .when_not_matched(WhenNotMatched::InsertAll)
8081 .try_build()
8082 .unwrap();
8083
8084 let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
8085 let new_stream = reader_to_stream(new_reader);
8086
8087 let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
8088
8089 assert_eq!(merge_stats.num_deleted_rows, 3);
8091 assert_eq!(merge_stats.num_inserted_rows, 3);
8092 assert_eq!(merge_stats.num_updated_rows, 0);
8093
8094 let batches = merged_dataset
8096 .scan()
8097 .try_into_stream()
8098 .await
8099 .unwrap()
8100 .try_collect::<Vec<_>>()
8101 .await
8102 .unwrap();
8103
8104 let merged = concat_batches(&schema, &batches).unwrap();
8105 let mut remaining_keys: Vec<u32> = merged
8106 .column(0)
8107 .as_primitive::<UInt32Type>()
8108 .values()
8109 .to_vec();
8110 remaining_keys.sort();
8111 assert_eq!(remaining_keys, vec![1, 2, 3, 7, 8, 9]);
8112
8113 let keyvals: Vec<(u32, u32)> = merged
8115 .column(0)
8116 .as_primitive::<UInt32Type>()
8117 .values()
8118 .iter()
8119 .zip(
8120 merged
8121 .column(1)
8122 .as_primitive::<UInt32Type>()
8123 .values()
8124 .iter(),
8125 )
8126 .map(|(&k, &v)| (k, v))
8127 .collect();
8128
8129 for (key, value) in keyvals {
8130 if key <= 3 {
8131 assert_eq!(value, 1, "Original keys should have value=1");
8132 } else {
8133 assert_eq!(value, 2, "New keys should have value=2");
8134 }
8135 }
8136 }
8137
8138 #[rstest::rstest]
8141 #[tokio::test]
8142 async fn test_when_matched_delete_no_matches(
8143 #[values(LanceFileVersion::Legacy, LanceFileVersion::V2_0)] version: LanceFileVersion,
8144 ) {
8145 let schema = create_test_schema();
8146 let test_uri = "memory://test_delete_no_matches.lance";
8147
8148 let ds = create_test_dataset(test_uri, version, false).await;
8150
8151 let non_matching_batch = RecordBatch::try_new(
8153 schema.clone(),
8154 vec![
8155 Arc::new(UInt32Array::from(vec![100, 200, 300])),
8156 Arc::new(UInt32Array::from(vec![10, 20, 30])),
8157 Arc::new(StringArray::from(vec!["X", "Y", "Z"])),
8158 ],
8159 )
8160 .unwrap();
8161
8162 let keys = vec!["key".to_string()];
8163
8164 let plan_job = MergeInsertBuilder::try_new(ds.clone(), keys.clone())
8166 .unwrap()
8167 .when_matched(WhenMatched::Delete)
8168 .when_not_matched(WhenNotMatched::DoNothing)
8169 .try_build()
8170 .unwrap();
8171 let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
8172 [Ok(non_matching_batch.clone())],
8173 schema.clone(),
8174 )));
8175 let plan = plan_job.create_plan(plan_stream).await.unwrap();
8176 assert_plan_node_equals(
8177 plan,
8178 "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing
8179 ...
8180 HashJoinExec: ...join_type=Inner...
8181 ...
8182 ...
8183 StreamingTableExec: partition_sizes=1, projection=[key]",
8184 )
8185 .await
8186 .unwrap();
8187 let job = MergeInsertBuilder::try_new(ds.clone(), keys)
8188 .unwrap()
8189 .when_matched(WhenMatched::Delete)
8190 .when_not_matched(WhenNotMatched::DoNothing)
8191 .try_build()
8192 .unwrap();
8193
8194 let new_reader = Box::new(RecordBatchIterator::new(
8195 [Ok(non_matching_batch)],
8196 schema.clone(),
8197 ));
8198 let new_stream = reader_to_stream(new_reader);
8199
8200 let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap();
8201
8202 assert_eq!(merge_stats.num_deleted_rows, 0);
8204 assert_eq!(merge_stats.num_inserted_rows, 0);
8205 assert_eq!(merge_stats.num_updated_rows, 0);
8206
8207 let batches = merged_dataset
8209 .scan()
8210 .try_into_stream()
8211 .await
8212 .unwrap()
8213 .try_collect::<Vec<_>>()
8214 .await
8215 .unwrap();
8216
8217 let merged = concat_batches(&schema, &batches).unwrap();
8218 let mut remaining_keys: Vec<u32> = merged
8219 .column(0)
8220 .as_primitive::<UInt32Type>()
8221 .values()
8222 .to_vec();
8223 remaining_keys.sort();
8224 assert_eq!(remaining_keys, vec![1, 2, 3, 4, 5, 6]);
8225 }
8226
8227 #[tokio::test]
8237 async fn test_is_delete_only() {
8238 use itertools::iproduct;
8239
8240 let when_matched_variants = [
8242 WhenMatched::UpdateAll,
8243 WhenMatched::DoNothing,
8244 WhenMatched::Fail,
8245 WhenMatched::Delete,
8246 ];
8247 let when_not_matched_variants = [WhenNotMatched::InsertAll, WhenNotMatched::DoNothing];
8248 let when_not_matched_by_source_variants =
8249 [WhenNotMatchedBySource::Keep, WhenNotMatchedBySource::Delete];
8250
8251 let schema = create_test_schema();
8252
8253 for (idx, (when_matched, when_not_matched, when_not_matched_by_source)) in iproduct!(
8254 when_matched_variants.iter().cloned(),
8255 when_not_matched_variants.iter().cloned(),
8256 when_not_matched_by_source_variants.iter().cloned()
8257 )
8258 .enumerate()
8259 {
8260 let is_no_op = matches!(when_matched, WhenMatched::DoNothing | WhenMatched::Fail)
8262 && matches!(when_not_matched, WhenNotMatched::DoNothing)
8263 && matches!(when_not_matched_by_source, WhenNotMatchedBySource::Keep);
8264 if is_no_op {
8265 continue;
8266 }
8267
8268 let test_uri = format!("memory://test_is_delete_only_{}.lance", idx);
8269 let ds = create_test_dataset(&test_uri, LanceFileVersion::V2_0, false).await;
8270
8271 let new_batch = RecordBatch::try_new(
8272 schema.clone(),
8273 vec![
8274 Arc::new(UInt32Array::from(vec![4, 5, 6])),
8275 Arc::new(UInt32Array::from(vec![2, 2, 2])),
8276 Arc::new(StringArray::from(vec!["A", "B", "C"])),
8277 ],
8278 )
8279 .unwrap();
8280
8281 let keys = vec!["key".to_string()];
8282
8283 let mut builder = MergeInsertBuilder::try_new(ds.clone(), keys).unwrap();
8284 builder
8285 .when_matched(when_matched.clone())
8286 .when_not_matched(when_not_matched.clone())
8287 .when_not_matched_by_source(when_not_matched_by_source.clone());
8288
8289 let job = builder.try_build().unwrap();
8290
8291 let plan_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
8292 [Ok(new_batch)],
8293 schema.clone(),
8294 )));
8295 let plan = job.create_plan(plan_stream).await.unwrap();
8296
8297 let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
8298 .indent(true)
8299 .to_string();
8300
8301 let expected_delete_only = matches!(when_matched, WhenMatched::Delete)
8302 && matches!(when_not_matched, WhenNotMatched::DoNothing)
8303 && matches!(when_not_matched_by_source, WhenNotMatchedBySource::Keep);
8304
8305 if expected_delete_only {
8306 assert!(
8307 plan_str.contains("DeleteOnlyMergeInsert"),
8308 "Expected DeleteOnlyMergeInsert for ({:?}, {:?}, {:?}), but got:\n{}",
8309 when_matched,
8310 when_not_matched,
8311 when_not_matched_by_source,
8312 plan_str
8313 );
8314 } else {
8315 assert!(
8316 plan_str.contains("MergeInsert:")
8317 && !plan_str.contains("DeleteOnlyMergeInsert"),
8318 "Expected MergeInsert (not DeleteOnlyMergeInsert) for ({:?}, {:?}, {:?}), but got:\n{}",
8319 when_matched,
8320 when_not_matched,
8321 when_not_matched_by_source,
8322 plan_str
8323 );
8324 }
8325 }
8326 }
8327
8328 #[tokio::test]
8330 async fn test_apply_deletions_invalid_row_address() {
8331 use super::exec::apply_deletions;
8332 use roaring::RoaringTreemap;
8333
8334 let test_uri = "memory://test_apply_deletions_error.lance";
8335
8336 let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await;
8338 let fragment_id = ds.get_fragments()[0].id() as u32;
8339
8340 let mut invalid_row_addrs = RoaringTreemap::new();
8348 let base = (fragment_id as u64) << 32;
8349 for row_offset in 10..14u64 {
8351 invalid_row_addrs.insert(base | row_offset);
8352 }
8353
8354 let result = apply_deletions(&ds, &invalid_row_addrs).await;
8355
8356 assert!(result.is_err(), "Expected error for invalid row addresses");
8357 let err = result.unwrap_err();
8358 assert!(
8359 err.to_string()
8360 .contains("Deletion vector includes rows that aren't in the fragment"),
8361 "Expected 'rows that aren't in the fragment' error, got: {}",
8362 err
8363 );
8364 }
8365
8366 mod external_error {
8367 use super::*;
8368 use arrow_schema::{ArrowError, Field as ArrowField, Schema as ArrowSchema};
8369 use std::fmt;
8370
8371 #[derive(Debug)]
8372 struct MyTestError {
8373 code: i32,
8374 details: String,
8375 }
8376
8377 impl fmt::Display for MyTestError {
8378 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
8379 write!(f, "MyTestError({}): {}", self.code, self.details)
8380 }
8381 }
8382
8383 impl std::error::Error for MyTestError {}
8384
8385 #[tokio::test]
8386 async fn test_merge_insert_execute_reader_preserves_external_error() {
8387 let schema = Arc::new(ArrowSchema::new(vec![
8388 ArrowField::new("key", DataType::Int32, false),
8389 ArrowField::new("value", DataType::Int32, false),
8390 ]));
8391
8392 let batch = RecordBatch::try_new(
8394 schema.clone(),
8395 vec![
8396 Arc::new(Int32Array::from(vec![1, 2, 3])),
8397 Arc::new(Int32Array::from(vec![10, 20, 30])),
8398 ],
8399 )
8400 .unwrap();
8401 let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
8402 let dataset = Arc::new(
8403 Dataset::write(reader, "memory://test_merge_external", None)
8404 .await
8405 .unwrap(),
8406 );
8407
8408 let error_code = 789;
8410 let iter = std::iter::once(Err(ArrowError::ExternalError(Box::new(MyTestError {
8411 code: error_code,
8412 details: "merge insert failure".to_string(),
8413 }))));
8414 let reader = RecordBatchIterator::new(iter, schema);
8415
8416 let result = MergeInsertBuilder::try_new(dataset, vec!["key".to_string()])
8417 .unwrap()
8418 .try_build()
8419 .unwrap()
8420 .execute_reader(Box::new(reader) as Box<dyn RecordBatchReader + Send>)
8421 .await;
8422
8423 match result {
8424 Err(Error::External { source }) => {
8425 let original = source.downcast_ref::<MyTestError>().unwrap();
8426 assert_eq!(original.code, error_code);
8427 }
8428 Err(other) => panic!("Expected External, got: {:?}", other),
8429 Ok(_) => panic!("Expected error"),
8430 }
8431 }
8432 }
8433
8434 async fn create_indexed_3frag_dataset() -> Arc<Dataset> {
8441 let schema = Arc::new(Schema::new(vec![
8442 Field::new("id", DataType::Utf8, false),
8443 Field::new("category", DataType::Utf8, false),
8444 Field::new("value_a", DataType::Float64, false),
8445 Field::new("value_b", DataType::Float64, false),
8446 ]));
8447
8448 let make_batch = |frag_idx: usize| {
8449 let start = frag_idx * 100;
8450 let ids: Vec<String> = (start..start + 100).map(|j| format!("id-{j:04}")).collect();
8451 let categories: Vec<&str> = vec!["A"; 100];
8452 let value_a: Vec<f64> = (0..100)
8453 .map(|i| i as f64 + frag_idx as f64 * 100.0)
8454 .collect();
8455 let value_b: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
8456 RecordBatch::try_new(
8457 schema.clone(),
8458 vec![
8459 Arc::new(StringArray::from(ids)),
8460 Arc::new(StringArray::from(categories)),
8461 Arc::new(Float64Array::from(value_a)),
8462 Arc::new(Float64Array::from(value_b)),
8463 ],
8464 )
8465 .unwrap()
8466 };
8467
8468 let batch0 = make_batch(0);
8470 let reader = Box::new(RecordBatchIterator::new([Ok(batch0)], schema.clone()));
8471 let mut ds = Dataset::write(reader, "memory://indexed_3frag", None)
8472 .await
8473 .unwrap();
8474
8475 for frag_idx in 1..3 {
8477 let batch = make_batch(frag_idx);
8478 let reader = Box::new(RecordBatchIterator::new([Ok(batch)], schema.clone()));
8479 ds.append(reader, None).await.unwrap();
8480 }
8481
8482 ds.create_index(
8484 &["id"],
8485 IndexType::BTree,
8486 None,
8487 &ScalarIndexParams::default(),
8488 false,
8489 )
8490 .await
8491 .unwrap();
8492
8493 Arc::new(ds)
8494 }
8495
8496 async fn partial_merge_insert(
8500 dataset: Arc<Dataset>,
8501 id_range: std::ops::Range<usize>,
8502 value_a_val: f64,
8503 ) -> Arc<Dataset> {
8504 let ids: Vec<String> = id_range.map(|j| format!("id-{j:04}")).collect();
8505 let n = ids.len();
8506 let sub_schema = Arc::new(Schema::new(vec![
8507 Field::new("id", DataType::Utf8, false),
8508 Field::new("value_a", DataType::Float64, false),
8509 ]));
8510 let batch = RecordBatch::try_new(
8511 sub_schema.clone(),
8512 vec![
8513 Arc::new(StringArray::from(ids)),
8514 Arc::new(Float64Array::from(vec![value_a_val; n])),
8515 ],
8516 )
8517 .unwrap();
8518 let reader = Box::new(RecordBatchIterator::new([Ok(batch)], sub_schema));
8519
8520 let (ds, _) = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()])
8521 .unwrap()
8522 .when_matched(WhenMatched::UpdateAll)
8523 .when_not_matched(WhenNotMatched::DoNothing)
8524 .try_build()
8525 .unwrap()
8526 .execute_reader(reader)
8527 .await
8528 .unwrap();
8529 ds
8530 }
8531
8532 #[tokio::test]
8539 async fn test_partial_merge_insert_stale_index_ambiguous() {
8540 let dataset = create_indexed_3frag_dataset().await;
8541
8542 let dataset = partial_merge_insert(dataset, 100..200, 999.0).await;
8544
8545 let dataset = partial_merge_insert(dataset, 100..200, 888.0).await;
8548
8549 let batches = dataset
8551 .scan()
8552 .try_into_stream()
8553 .await
8554 .unwrap()
8555 .try_collect::<Vec<_>>()
8556 .await
8557 .unwrap();
8558 let all_schema = Arc::new(Schema::new(vec![
8559 Field::new("id", DataType::Utf8, false),
8560 Field::new("category", DataType::Utf8, false),
8561 Field::new("value_a", DataType::Float64, false),
8562 Field::new("value_b", DataType::Float64, false),
8563 ]));
8564 let combined = concat_batches(&all_schema, &batches).unwrap();
8565 assert_eq!(combined.num_rows(), 300);
8566
8567 let result = dataset
8569 .scan()
8570 .filter("id >= 'id-0100' AND id < 'id-0200'")
8571 .unwrap()
8572 .try_into_stream()
8573 .await
8574 .unwrap()
8575 .try_collect::<Vec<_>>()
8576 .await
8577 .unwrap();
8578 let result = concat_batches(&all_schema, &result).unwrap();
8579 assert_eq!(result.num_rows(), 100);
8580 let values = result
8581 .column_by_name("value_a")
8582 .unwrap()
8583 .as_any()
8584 .downcast_ref::<Float64Array>()
8585 .unwrap();
8586 for i in 0..100 {
8587 assert_eq!(values.value(i), 888.0, "row {i} should have value_a=888.0");
8588 }
8589 }
8590
8591 #[tokio::test]
8599 async fn test_partial_merge_insert_stale_index_fragment_not_exist() {
8600 let dataset = create_indexed_3frag_dataset().await;
8601
8602 let dataset = partial_merge_insert(dataset, 100..200, 999.0).await;
8604
8605 let update_result = crate::dataset::UpdateBuilder::new(Arc::new((*dataset).clone()))
8608 .update_where("id >= 'id-0100' AND id < 'id-0200'")
8609 .unwrap()
8610 .set("category", "'B'")
8611 .unwrap()
8612 .build()
8613 .unwrap()
8614 .execute()
8615 .await
8616 .unwrap();
8617 let dataset = update_result.new_dataset;
8618
8619 let dataset = partial_merge_insert(dataset, 100..200, 888.0).await;
8622
8623 let batches = dataset
8625 .scan()
8626 .try_into_stream()
8627 .await
8628 .unwrap()
8629 .try_collect::<Vec<_>>()
8630 .await
8631 .unwrap();
8632 let all_schema = Arc::new(Schema::new(vec![
8633 Field::new("id", DataType::Utf8, false),
8634 Field::new("category", DataType::Utf8, false),
8635 Field::new("value_a", DataType::Float64, false),
8636 Field::new("value_b", DataType::Float64, false),
8637 ]));
8638 let combined = concat_batches(&all_schema, &batches).unwrap();
8639 assert_eq!(combined.num_rows(), 300);
8640 }
8641
8642 #[tokio::test]
8651 async fn test_partial_merge_insert_stale_index_batch_size_mismatch() {
8652 let dataset = create_indexed_3frag_dataset().await;
8653
8654 let dataset = partial_merge_insert(dataset, 100..200, 999.0).await;
8656
8657 let update_result = crate::dataset::UpdateBuilder::new(Arc::new((*dataset).clone()))
8660 .update_where("id >= 'id-0100' AND id < 'id-0150'")
8661 .unwrap()
8662 .set("category", "'B'")
8663 .unwrap()
8664 .build()
8665 .unwrap()
8666 .execute()
8667 .await
8668 .unwrap();
8669 let dataset = update_result.new_dataset;
8670
8671 let dataset = partial_merge_insert(dataset, 100..150, 888.0).await;
8674
8675 let batches = dataset
8677 .scan()
8678 .try_into_stream()
8679 .await
8680 .unwrap()
8681 .try_collect::<Vec<_>>()
8682 .await
8683 .unwrap();
8684 let all_schema = Arc::new(Schema::new(vec![
8685 Field::new("id", DataType::Utf8, false),
8686 Field::new("category", DataType::Utf8, false),
8687 Field::new("value_a", DataType::Float64, false),
8688 Field::new("value_b", DataType::Float64, false),
8689 ]));
8690 let combined = concat_batches(&all_schema, &batches).unwrap();
8691 assert_eq!(combined.num_rows(), 300);
8692 }
8693
8694 #[tokio::test]
8699 async fn test_partial_merge_insert_stale_vector_index_duplicates() {
8700 let dim = 4i32;
8701 let rows_per_frag = 10usize;
8702 let num_frags = 3usize;
8703 let total_rows = rows_per_frag * num_frags;
8704
8705 let schema = Arc::new(Schema::new(vec![
8706 Field::new("id", DataType::Utf8, false),
8707 Field::new("category", DataType::Utf8, false),
8708 Field::new(
8709 "vec",
8710 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
8711 false,
8712 ),
8713 ]));
8714
8715 let make_batch = |frag_idx: usize, offset: f32| {
8716 let start = frag_idx * rows_per_frag;
8717 let ids: Vec<String> = (start..start + rows_per_frag)
8718 .map(|j| format!("id-{j:04}"))
8719 .collect();
8720 let cats: Vec<&str> = vec!["A"; rows_per_frag];
8721 let values: Vec<f32> = (0..rows_per_frag * dim as usize)
8722 .map(|i| (start * dim as usize + i) as f32 + offset)
8723 .collect();
8724 let vectors =
8725 FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim).unwrap();
8726 RecordBatch::try_new(
8727 schema.clone(),
8728 vec![
8729 Arc::new(StringArray::from(ids)),
8730 Arc::new(StringArray::from(cats)),
8731 Arc::new(vectors),
8732 ],
8733 )
8734 .unwrap()
8735 };
8736
8737 let batch0 = make_batch(0, 0.0);
8739 let reader = Box::new(RecordBatchIterator::new([Ok(batch0)], schema.clone()));
8740 let mut ds = Dataset::write(reader, "memory://vector_stale_test", None)
8741 .await
8742 .unwrap();
8743 for frag_idx in 1..num_frags {
8744 let batch = make_batch(frag_idx, 0.0);
8745 let reader = Box::new(RecordBatchIterator::new([Ok(batch)], schema.clone()));
8746 ds.append(reader, None).await.unwrap();
8747 }
8748
8749 let params = VectorIndexParams::ivf_flat(1, MetricType::L2);
8751 ds.create_index(&["vec"], IndexType::Vector, None, ¶ms, false)
8752 .await
8753 .unwrap();
8754
8755 let ds = Arc::new(ds);
8756
8757 let frag1_start = rows_per_frag;
8760 let ids: Vec<String> = (frag1_start..frag1_start + rows_per_frag)
8761 .map(|j| format!("id-{j:04}"))
8762 .collect();
8763 let sub_schema = Arc::new(Schema::new(vec![
8764 Field::new("id", DataType::Utf8, false),
8765 Field::new(
8766 "vec",
8767 DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
8768 false,
8769 ),
8770 ]));
8771 let values: Vec<f32> = (0..rows_per_frag * dim as usize)
8772 .map(|i| (frag1_start * dim as usize + i) as f32 + 0.5)
8773 .collect();
8774 let vectors =
8775 FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim).unwrap();
8776 let update_batch = RecordBatch::try_new(
8777 sub_schema.clone(),
8778 vec![Arc::new(StringArray::from(ids)), Arc::new(vectors)],
8779 )
8780 .unwrap();
8781 let reader = Box::new(RecordBatchIterator::new([Ok(update_batch)], sub_schema));
8782 let (ds, _) = MergeInsertBuilder::try_new(ds, vec!["id".to_string()])
8783 .unwrap()
8784 .when_matched(WhenMatched::UpdateAll)
8785 .when_not_matched(WhenNotMatched::DoNothing)
8786 .try_build()
8787 .unwrap()
8788 .execute_reader(reader)
8789 .await
8790 .unwrap();
8791
8792 let query: Float32Array = (0..dim)
8794 .map(|i| (frag1_start * dim as usize + i as usize) as f32 + 0.5)
8795 .collect();
8796 let results = ds
8797 .scan()
8798 .nearest("vec", &query, total_rows)
8799 .unwrap()
8800 .try_into_batch()
8801 .await
8802 .unwrap();
8803
8804 let ids = results
8806 .column_by_name("id")
8807 .unwrap()
8808 .as_any()
8809 .downcast_ref::<StringArray>()
8810 .unwrap();
8811 let unique_ids: std::collections::HashSet<&str> =
8812 (0..ids.len()).map(|i| ids.value(i)).collect();
8813 assert_eq!(
8814 unique_ids.len(),
8815 ids.len(),
8816 "Found duplicate ids in KNN results: {} unique out of {} total",
8817 unique_ids.len(),
8818 ids.len()
8819 );
8820 }
8821
8822 #[tokio::test]
8827 async fn test_partial_merge_insert_stale_fts_index_duplicates() {
8828 let rows_per_frag = 10usize;
8829 let num_frags = 3usize;
8830
8831 let schema = Arc::new(Schema::new(vec![
8832 Field::new("id", DataType::Utf8, false),
8833 Field::new("category", DataType::Utf8, false),
8834 Field::new("text", DataType::Utf8, false),
8835 ]));
8836
8837 let make_batch = |frag_idx: usize| {
8838 let start = frag_idx * rows_per_frag;
8839 let ids: Vec<String> = (start..start + rows_per_frag)
8840 .map(|j| format!("id-{j:04}"))
8841 .collect();
8842 let cats: Vec<&str> = vec!["A"; rows_per_frag];
8843 let texts: Vec<String> = (start..start + rows_per_frag)
8845 .map(|j| format!("common unique{j:04}"))
8846 .collect();
8847 RecordBatch::try_new(
8848 schema.clone(),
8849 vec![
8850 Arc::new(StringArray::from(ids)),
8851 Arc::new(StringArray::from(cats)),
8852 Arc::new(StringArray::from(texts)),
8853 ],
8854 )
8855 .unwrap()
8856 };
8857
8858 let batch0 = make_batch(0);
8860 let reader = Box::new(RecordBatchIterator::new([Ok(batch0)], schema.clone()));
8861 let mut ds = Dataset::write(reader, "memory://fts_stale_test", None)
8862 .await
8863 .unwrap();
8864 for frag_idx in 1..num_frags {
8865 let batch = make_batch(frag_idx);
8866 let reader = Box::new(RecordBatchIterator::new([Ok(batch)], schema.clone()));
8867 ds.append(reader, None).await.unwrap();
8868 }
8869
8870 let params = InvertedIndexParams::default();
8872 ds.create_index(&["text"], IndexType::Inverted, None, ¶ms, true)
8873 .await
8874 .unwrap();
8875
8876 let ds = Arc::new(ds);
8877
8878 let frag1_start = rows_per_frag;
8882 let ids: Vec<String> = (frag1_start..frag1_start + rows_per_frag)
8883 .map(|j| format!("id-{j:04}"))
8884 .collect();
8885 let texts: Vec<String> = (frag1_start..frag1_start + rows_per_frag)
8886 .map(|j| format!("common updated{j:04}"))
8887 .collect();
8888 let sub_schema = Arc::new(Schema::new(vec![
8889 Field::new("id", DataType::Utf8, false),
8890 Field::new("text", DataType::Utf8, false),
8891 ]));
8892 let update_batch = RecordBatch::try_new(
8893 sub_schema.clone(),
8894 vec![
8895 Arc::new(StringArray::from(ids)),
8896 Arc::new(StringArray::from(texts)),
8897 ],
8898 )
8899 .unwrap();
8900 let reader = Box::new(RecordBatchIterator::new([Ok(update_batch)], sub_schema));
8901 let (ds, _) = MergeInsertBuilder::try_new(ds, vec!["id".to_string()])
8902 .unwrap()
8903 .when_matched(WhenMatched::UpdateAll)
8904 .when_not_matched(WhenNotMatched::DoNothing)
8905 .try_build()
8906 .unwrap()
8907 .execute_reader(reader)
8908 .await
8909 .unwrap();
8910
8911 let query = FullTextSearchQuery::new("common".to_string());
8913 let results = ds
8914 .scan()
8915 .full_text_search(query)
8916 .unwrap()
8917 .try_into_batch()
8918 .await
8919 .unwrap();
8920
8921 let ids = results
8923 .column_by_name("id")
8924 .unwrap()
8925 .as_any()
8926 .downcast_ref::<StringArray>()
8927 .unwrap();
8928 let unique_ids: std::collections::HashSet<&str> =
8929 (0..ids.len()).map(|i| ids.value(i)).collect();
8930 assert_eq!(
8931 unique_ids.len(),
8932 ids.len(),
8933 "Found duplicate ids in FTS results: {} unique out of {} total",
8934 unique_ids.len(),
8935 ids.len()
8936 );
8937 assert_eq!(
8939 unique_ids.len(),
8940 rows_per_frag * num_frags,
8941 "Expected {} rows but got {}",
8942 rows_per_frag * num_frags,
8943 unique_ids.len()
8944 );
8945 }
8946
8947 #[tokio::test]
8956 async fn test_compaction_after_invalidated_fragment() {
8957 use crate::dataset::optimize::{CompactionOptions, compact_files};
8958
8959 let rows_per_frag = 20;
8963 let num_frags = 5;
8964 let total_rows = rows_per_frag * num_frags;
8965 let schema = Arc::new(Schema::new(vec![
8966 Field::new("id", DataType::Utf8, false),
8967 Field::new("category", DataType::Utf8, false),
8968 Field::new("value_a", DataType::Float64, false),
8969 Field::new("value_b", DataType::Float64, false),
8970 ]));
8971
8972 let make_batch = |frag_idx: usize| {
8973 let start = frag_idx * rows_per_frag;
8974 let ids: Vec<String> = (start..start + rows_per_frag)
8975 .map(|j| format!("id-{j:04}"))
8976 .collect();
8977 RecordBatch::try_new(
8978 schema.clone(),
8979 vec![
8980 Arc::new(StringArray::from(ids)),
8981 Arc::new(StringArray::from(vec!["A"; rows_per_frag])),
8982 Arc::new(Float64Array::from(
8983 (0..rows_per_frag).map(|i| i as f64).collect::<Vec<_>>(),
8984 )),
8985 Arc::new(Float64Array::from(
8986 (0..rows_per_frag)
8987 .map(|i| i as f64 * 0.1)
8988 .collect::<Vec<_>>(),
8989 )),
8990 ],
8991 )
8992 .unwrap()
8993 };
8994
8995 let batch0 = make_batch(0);
8996 let reader = Box::new(RecordBatchIterator::new([Ok(batch0)], schema.clone()));
8997 let mut ds = Dataset::write(reader, "memory://compaction_test", None)
8998 .await
8999 .unwrap();
9000 for frag_idx in 1..num_frags {
9001 let batch = make_batch(frag_idx);
9002 let reader = Box::new(RecordBatchIterator::new([Ok(batch)], schema.clone()));
9003 ds.append(reader, None).await.unwrap();
9004 }
9005 ds.create_index(
9006 &["id"],
9007 IndexType::BTree,
9008 None,
9009 &ScalarIndexParams::default(),
9010 false,
9011 )
9012 .await
9013 .unwrap();
9014
9015 let ds = Arc::new(ds);
9016
9017 let frag2_start = 2 * rows_per_frag;
9019 let ds = partial_merge_insert(ds, frag2_start..frag2_start + rows_per_frag, 999.0).await;
9020
9021 let indices = ds.load_indices().await.unwrap();
9023 let idx = indices.iter().find(|i| i.name == "id_idx").unwrap();
9024 assert!(!idx.fragment_bitmap.as_ref().unwrap().contains(2));
9025
9026 let mut ds = (*ds).clone();
9028 let opts = CompactionOptions {
9029 target_rows_per_fragment: total_rows,
9030 ..Default::default()
9031 };
9032 compact_files(&mut ds, opts, None).await.unwrap();
9033
9034 let indices = ds.load_indices().await.unwrap();
9038 let idx = indices.iter().find(|i| i.name == "id_idx").unwrap();
9039 let bitmap = idx.fragment_bitmap.as_ref().unwrap();
9040 for &old_id in &[0u32, 1, 3, 4] {
9041 assert!(
9042 !bitmap.contains(old_id),
9043 "Old indexed fragment {} should not be in bitmap after compaction",
9044 old_id
9045 );
9046 }
9047 assert!(
9048 !bitmap.is_empty(),
9049 "Bitmap should have new compacted fragments"
9050 );
9051
9052 let ds = Arc::new(ds);
9058 let ds = partial_merge_insert(ds, frag2_start..frag2_start + rows_per_frag, 888.0).await;
9059
9060 let batches = ds
9062 .scan()
9063 .try_into_stream()
9064 .await
9065 .unwrap()
9066 .try_collect::<Vec<_>>()
9067 .await
9068 .unwrap();
9069 let combined = concat_batches(&schema, &batches).unwrap();
9070 assert_eq!(combined.num_rows(), total_rows);
9071
9072 let result = ds
9074 .scan()
9075 .filter(&format!(
9076 "id >= 'id-{:04}' AND id < 'id-{:04}'",
9077 frag2_start,
9078 frag2_start + rows_per_frag
9079 ))
9080 .unwrap()
9081 .try_into_stream()
9082 .await
9083 .unwrap()
9084 .try_collect::<Vec<_>>()
9085 .await
9086 .unwrap();
9087 let result = concat_batches(&schema, &result).unwrap();
9088 assert_eq!(result.num_rows(), rows_per_frag);
9089 let values = result
9090 .column_by_name("value_a")
9091 .unwrap()
9092 .as_any()
9093 .downcast_ref::<Float64Array>()
9094 .unwrap();
9095 for i in 0..rows_per_frag {
9096 assert_eq!(values.value(i), 888.0, "row {i} should have value_a=888.0");
9097 }
9098 }
9099}