1use arrow_array::{
2 Array, ArrayRef, RecordBatch, StringArray, StructArray, UInt8Array, UInt32Array, UInt64Array,
3};
4use arrow_schema::SchemaRef;
5use arrow_select::concat::concat_batches;
6use futures::TryStreamExt;
7use lance::Dataset;
8use lance::blob::BlobArrayBuilder;
9use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
10use lance::dataset::transaction::{Operation, Transaction, TransactionBuilder};
11use lance::dataset::write::merge_insert::SourceDedupeBehavior;
12use lance::dataset::{
13 CommitBuilder, InsertBuilder, MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode,
14 WriteParams,
15};
16use lance::datatypes::BlobKind;
17use lance::index::DatasetIndexExt;
18use lance::index::scalar::IndexDetails;
19use lance_file::version::LanceFileVersion;
20use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
21use lance_index::{IndexType, is_system_index};
22use lance_linalg::distance::MetricType;
23use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
24use lance_table::rowids::{RowIdSequence, write_row_ids};
25use std::sync::Arc;
26
27use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
28use crate::db::{Snapshot, SubTableEntry};
29use crate::error::{OmniError, Result};
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct TableState {
33 pub version: u64,
34 pub row_count: u64,
35 pub(crate) version_metadata: TableVersionMetadata,
36}
37
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct DeleteState {
40 pub version: u64,
41 pub row_count: u64,
42 pub deleted_rows: usize,
43 pub(crate) version_metadata: TableVersionMetadata,
44}
45
46#[derive(Debug, Clone)]
68pub struct StagedWrite {
69 pub transaction: Transaction,
70 pub new_fragments: Vec<Fragment>,
77 pub removed_fragment_ids: Vec<u64>,
84}
85
/// Handle for the table datasets that live under a single root URI.
#[derive(Clone, Debug)]
pub struct TableStore {
    /// Root URI; `TableStore::new` stores it without trailing `/` characters.
    root_uri: String,
}
90
91impl TableStore {
92 pub fn new(root_uri: &str) -> Self {
93 Self {
94 root_uri: root_uri.trim_end_matches('/').to_string(),
95 }
96 }
97
98 pub fn root_uri(&self) -> &str {
99 &self.root_uri
100 }
101
102 pub fn dataset_uri(&self, table_path: &str) -> String {
103 format!("{}/{}", self.root_uri, table_path)
104 }
105
106 fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
107 let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
108 let table_path = dataset_uri
109 .strip_prefix(&prefix)
110 .map(|path| path.to_string())
111 .ok_or_else(|| {
112 OmniError::manifest_internal(format!(
113 "dataset uri '{}' is not under root '{}'",
114 dataset_uri, self.root_uri
115 ))
116 })?;
117 Ok(table_path
118 .split_once("/tree/")
119 .map(|(path, _)| path.to_string())
120 .unwrap_or(table_path))
121 }
122
123 fn dataset_version_metadata(
124 &self,
125 dataset_uri: &str,
126 ds: &Dataset,
127 ) -> Result<TableVersionMetadata> {
128 let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
129 TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
130 }
131
132 pub async fn open_snapshot_table(
133 &self,
134 snapshot: &Snapshot,
135 table_key: &str,
136 ) -> Result<Dataset> {
137 snapshot.open(table_key).await
138 }
139
140 pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
141 entry.open(&self.root_uri).await
142 }
143
144 pub async fn open_dataset_head(
145 &self,
146 dataset_uri: &str,
147 branch: Option<&str>,
148 ) -> Result<Dataset> {
149 let ds = Dataset::open(dataset_uri)
150 .await
151 .map_err(|e| OmniError::Lance(e.to_string()))?;
152 match branch {
153 Some(branch) if branch != "main" => ds
154 .checkout_branch(branch)
155 .await
156 .map_err(|e| OmniError::Lance(e.to_string())),
157 _ => Ok(ds),
158 }
159 }
160
161 pub async fn open_dataset_head_for_write(
162 &self,
163 table_key: &str,
164 dataset_uri: &str,
165 branch: Option<&str>,
166 ) -> Result<Dataset> {
167 let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
168 open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
169 }
170
171 pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
172 let mut ds = Dataset::open(dataset_uri)
173 .await
174 .map_err(|e| OmniError::Lance(e.to_string()))?;
175 ds.delete_branch(branch)
176 .await
177 .map_err(|e| OmniError::Lance(e.to_string()))
178 }
179
180 pub async fn list_branches(&self, dataset_uri: &str) -> Result<Vec<String>> {
185 let ds = Dataset::open(dataset_uri)
186 .await
187 .map_err(|e| OmniError::Lance(e.to_string()))?;
188 let branches = ds
189 .list_branches()
190 .await
191 .map_err(|e| OmniError::Lance(e.to_string()))?;
192 Ok(branches.into_keys().collect())
193 }
194
195 pub async fn force_delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
209 let mut ds = Dataset::open(dataset_uri)
210 .await
211 .map_err(|e| OmniError::Lance(e.to_string()))?;
212 match ds.force_delete_branch(branch).await {
213 Ok(()) => Ok(()),
214 Err(lance::Error::RefNotFound { .. }) | Err(lance::Error::NotFound { .. }) => Ok(()),
215 Err(e) => Err(OmniError::Lance(e.to_string())),
216 }
217 }
218
219 pub async fn open_dataset_at_state(
220 &self,
221 table_path: &str,
222 branch: Option<&str>,
223 version: u64,
224 ) -> Result<Dataset> {
225 let ds = self
226 .open_dataset_head(&self.dataset_uri(table_path), branch)
227 .await?;
228 ds.checkout_version(version)
229 .await
230 .map_err(|e| OmniError::Lance(e.to_string()))
231 }
232
233 pub fn ensure_expected_version(
234 &self,
235 ds: &Dataset,
236 table_key: &str,
237 expected_version: u64,
238 ) -> Result<()> {
239 let actual = ds.version().version;
240 if actual != expected_version {
241 return Err(OmniError::manifest_expected_version_mismatch(
247 table_key,
248 expected_version,
249 actual,
250 ));
251 }
252 Ok(())
253 }
254
255 pub async fn reopen_for_mutation(
256 &self,
257 dataset_uri: &str,
258 branch: Option<&str>,
259 table_key: &str,
260 expected_version: u64,
261 ) -> Result<Dataset> {
262 let ds = self
263 .open_dataset_head_for_write(table_key, dataset_uri, branch)
264 .await?;
265 self.ensure_expected_version(&ds, table_key, expected_version)?;
266 Ok(ds)
267 }
268
269 pub async fn fork_branch_from_state(
270 &self,
271 dataset_uri: &str,
272 source_branch: Option<&str>,
273 table_key: &str,
274 source_version: u64,
275 target_branch: &str,
276 ) -> Result<Dataset> {
277 let mut source_ds = self
278 .open_dataset_head(dataset_uri, source_branch)
279 .await?
280 .checkout_version(source_version)
281 .await
282 .map_err(|e| OmniError::Lance(e.to_string()))?;
283 self.ensure_expected_version(&source_ds, table_key, source_version)?;
284
285 if source_ds
286 .create_branch(target_branch, source_version, None)
287 .await
288 .is_err()
289 {
290 return Err(OmniError::manifest_conflict(format!(
298 "branch '{}' has orphaned table state for '{}' from an incomplete \
299 prior delete; run `omnigraph cleanup` to reclaim it before reusing \
300 this branch name",
301 target_branch, table_key
302 )));
303 }
304
305 let ds = self
306 .open_dataset_head(dataset_uri, Some(target_branch))
307 .await?;
308 self.ensure_expected_version(&ds, table_key, source_version)?;
309 Ok(ds)
310 }
311
312 pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
313 self.scan(ds, None, None, None).await
314 }
315
316 pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
317 let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
318 if !has_blob_columns {
319 return self.scan_batches(ds).await;
320 }
321
322 let batches = Self::scan_stream(ds, None, None, None, true)
323 .await?
324 .try_collect::<Vec<RecordBatch>>()
325 .await
326 .map_err(|e| OmniError::Lance(e.to_string()))?;
327 let mut materialized = Vec::with_capacity(batches.len());
328 for batch in batches {
329 materialized.push(Self::materialize_blob_batch(ds, batch).await?);
330 }
331 Ok(materialized)
332 }
333
334 pub(crate) async fn materialize_blob_batch(
335 ds: &Dataset,
336 batch: RecordBatch,
337 ) -> Result<RecordBatch> {
338 let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
339 if !has_blob_columns {
340 return Ok(batch);
341 }
342
343 let row_ids = batch
344 .column_by_name("_rowid")
345 .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
346 .ok_or_else(|| {
347 OmniError::Lance("expected _rowid column when materializing blobs".to_string())
348 })?
349 .values()
350 .iter()
351 .copied()
352 .collect::<Vec<_>>();
353
354 let schema: SchemaRef = Arc::new(ds.schema().into());
355 let mut columns = Vec::with_capacity(schema.fields().len());
356 for field in schema.fields() {
357 let lance_field = lance::datatypes::Field::try_from(field.as_ref())
358 .map_err(|e| OmniError::Lance(e.to_string()))?;
359 let column = batch.column_by_name(field.name()).ok_or_else(|| {
360 OmniError::Lance(format!("batch missing column '{}'", field.name()))
361 })?;
362 if lance_field.is_blob() {
363 let descriptions =
364 column
365 .as_any()
366 .downcast_ref::<StructArray>()
367 .ok_or_else(|| {
368 OmniError::Lance(format!(
369 "expected blob descriptions for '{}'",
370 field.name()
371 ))
372 })?;
373 columns.push(
374 Self::rebuild_blob_column(ds, field.name(), descriptions, &row_ids).await?,
375 );
376 } else {
377 columns.push(column.clone());
378 }
379 }
380
381 RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
382 }
383
384 async fn rebuild_blob_column(
385 ds: &Dataset,
386 column_name: &str,
387 descriptions: &StructArray,
388 row_ids: &[u64],
389 ) -> Result<ArrayRef> {
390 let mut builder = BlobArrayBuilder::new(row_ids.len());
391 let mut non_null_row_ids = Vec::new();
392 let mut row_has_blob = Vec::with_capacity(row_ids.len());
393
394 for row in 0..row_ids.len() {
395 let is_null = Self::blob_description_is_null(descriptions, row)?;
396 row_has_blob.push(!is_null);
397 if !is_null {
398 non_null_row_ids.push(row_ids[row]);
399 }
400 }
401
402 let blob_files = if non_null_row_ids.is_empty() {
403 Vec::new()
404 } else {
405 Arc::new(ds.clone())
406 .take_blobs(&non_null_row_ids, column_name)
407 .await
408 .map_err(|e| OmniError::Lance(e.to_string()))?
409 };
410
411 let mut files = blob_files.into_iter();
412 for has_blob in row_has_blob {
413 if !has_blob {
414 builder
415 .push_null()
416 .map_err(|e| OmniError::Lance(e.to_string()))?;
417 continue;
418 }
419
420 let blob = files.next().ok_or_else(|| {
421 OmniError::Lance(format!(
422 "blob rewrite for '{}' lost alignment with source rows",
423 column_name
424 ))
425 })?;
426 builder
427 .push_bytes(
428 blob.read()
429 .await
430 .map_err(|e| OmniError::Lance(e.to_string()))?,
431 )
432 .map_err(|e| OmniError::Lance(e.to_string()))?;
433 }
434
435 if files.next().is_some() {
436 return Err(OmniError::Lance(format!(
437 "blob rewrite for '{}' produced extra source blobs",
438 column_name
439 )));
440 }
441
442 builder
443 .finish()
444 .map_err(|e| OmniError::Lance(e.to_string()))
445 }
446
447 fn blob_description_is_null(descriptions: &StructArray, row: usize) -> Result<bool> {
448 if descriptions.is_null(row) {
449 return Ok(true);
450 }
451
452 let position = descriptions
453 .column_by_name("position")
454 .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
455 .ok_or_else(|| {
456 OmniError::Lance(format!(
457 "unrecognized blob description schema {:?}: missing UInt64 position field",
458 descriptions.fields()
459 ))
460 })?;
461 let size = descriptions
462 .column_by_name("size")
463 .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
464 .ok_or_else(|| {
465 OmniError::Lance(format!(
466 "unrecognized blob description schema {:?}: missing UInt64 size field",
467 descriptions.fields()
468 ))
469 })?;
470
471 let Some(kind_column) = descriptions.column_by_name("kind") else {
472 return Ok(position.is_null(row) || size.is_null(row));
473 };
474 let kind = if let Some(kind) = kind_column.as_any().downcast_ref::<UInt8Array>() {
475 if kind.is_null(row) {
476 return Ok(true);
477 }
478 kind.value(row)
479 } else if let Some(kind) = kind_column.as_any().downcast_ref::<UInt32Array>() {
480 if kind.is_null(row) {
481 return Ok(true);
482 }
483 kind.value(row) as u8
484 } else {
485 return Err(OmniError::Lance(format!(
486 "unrecognized blob description schema {:?}: kind field must be UInt8 or UInt32",
487 descriptions.fields()
488 )));
489 };
490
491 let kind = BlobKind::try_from(kind).map_err(|e| OmniError::Lance(e.to_string()))?;
492 if kind != BlobKind::Inline {
493 return Ok(false);
494 }
495 let blob_uri = descriptions
496 .column_by_name("blob_uri")
497 .and_then(|col| col.as_any().downcast_ref::<StringArray>())
498 .and_then(|arr| (!arr.is_null(row)).then(|| arr.value(row)));
499
500 Ok((position.is_null(row) || position.value(row) == 0)
501 && (size.is_null(row) || size.value(row) == 0)
502 && blob_uri.unwrap_or("").is_empty())
503 }
504
505 pub async fn scan_stream(
506 ds: &Dataset,
507 projection: Option<&[&str]>,
508 filter: Option<&str>,
509 order_by: Option<Vec<ColumnOrdering>>,
510 with_row_id: bool,
511 ) -> Result<DatasetRecordBatchStream> {
512 Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
513 }
514
515 pub async fn scan_stream_with<F>(
516 ds: &Dataset,
517 projection: Option<&[&str]>,
518 filter: Option<&str>,
519 order_by: Option<Vec<ColumnOrdering>>,
520 with_row_id: bool,
521 configure: F,
522 ) -> Result<DatasetRecordBatchStream>
523 where
524 F: FnOnce(&mut Scanner) -> Result<()>,
525 {
526 let mut scanner = ds.scan();
527 if with_row_id {
528 scanner.with_row_id();
529 }
530 if let Some(columns) = projection {
531 scanner
532 .project(columns)
533 .map_err(|e| OmniError::Lance(e.to_string()))?;
534 }
535 if let Some(filter_sql) = filter {
536 scanner
537 .filter(filter_sql)
538 .map_err(|e| OmniError::Lance(e.to_string()))?;
539 }
540 if let Some(ordering) = order_by {
541 scanner
542 .order_by(Some(ordering))
543 .map_err(|e| OmniError::Lance(e.to_string()))?;
544 }
545 configure(&mut scanner)?;
546 scanner
547 .try_into_stream()
548 .await
549 .map_err(|e| OmniError::Lance(e.to_string()))
550 }
551
552 pub async fn scan(
553 &self,
554 ds: &Dataset,
555 projection: Option<&[&str]>,
556 filter: Option<&str>,
557 order_by: Option<Vec<ColumnOrdering>>,
558 ) -> Result<Vec<RecordBatch>> {
559 Self::scan_stream(ds, projection, filter, order_by, false)
560 .await?
561 .try_collect()
562 .await
563 .map_err(|e| OmniError::Lance(e.to_string()))
564 }
565
566 pub async fn scan_with<F>(
567 &self,
568 ds: &Dataset,
569 projection: Option<&[&str]>,
570 filter: Option<&str>,
571 order_by: Option<Vec<ColumnOrdering>>,
572 with_row_id: bool,
573 configure: F,
574 ) -> Result<Vec<RecordBatch>>
575 where
576 F: FnOnce(&mut Scanner) -> Result<()>,
577 {
578 Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
579 .await?
580 .try_collect()
581 .await
582 .map_err(|e| OmniError::Lance(e.to_string()))
583 }
584
585 pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
586 ds.count_rows(filter)
587 .await
588 .map(|count| count as usize)
589 .map_err(|e| OmniError::Lance(e.to_string()))
590 }
591
592 pub fn dataset_version(&self, ds: &Dataset) -> u64 {
593 ds.version().version
594 }
595
596 pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
597 Ok(TableState {
598 version: self.dataset_version(ds),
599 row_count: self.count_rows(ds, None).await? as u64,
600 version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
601 })
602 }
603
604 pub async fn append_batch(
605 &self,
606 dataset_uri: &str,
607 ds: &mut Dataset,
608 batch: RecordBatch,
609 ) -> Result<TableState> {
610 if batch.num_rows() == 0 {
611 return self.table_state(dataset_uri, ds).await;
612 }
613 let schema = batch.schema();
614 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
615 let params = WriteParams {
616 mode: WriteMode::Append,
617 allow_external_blob_outside_bases: true,
618 ..Default::default()
619 };
620 ds.append(reader, Some(params))
621 .await
622 .map_err(|e| OmniError::Lance(e.to_string()))?;
623 self.table_state(dataset_uri, ds).await
624 }
625
626 pub async fn append_or_create_batch(
627 dataset_uri: &str,
628 dataset: Option<Dataset>,
629 batch: RecordBatch,
630 ) -> Result<Dataset> {
631 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
632 match dataset {
633 Some(mut ds) => {
634 let params = WriteParams {
635 mode: WriteMode::Append,
636 allow_external_blob_outside_bases: true,
637 ..Default::default()
638 };
639 ds.append(reader, Some(params))
640 .await
641 .map_err(|e| OmniError::Lance(e.to_string()))?;
642 Ok(ds)
643 }
644 None => {
645 let params = WriteParams {
646 mode: WriteMode::Create,
647 enable_stable_row_ids: true,
648 data_storage_version: Some(LanceFileVersion::V2_2),
649 allow_external_blob_outside_bases: true,
650 ..Default::default()
651 };
652 Dataset::write(reader, dataset_uri, Some(params))
653 .await
654 .map_err(|e| OmniError::Lance(e.to_string()))
655 }
656 }
657 }
658
659 pub async fn overwrite_batch(
660 &self,
661 dataset_uri: &str,
662 ds: &mut Dataset,
663 batch: RecordBatch,
664 ) -> Result<TableState> {
665 ds.truncate_table()
666 .await
667 .map_err(|e| OmniError::Lance(e.to_string()))?;
668 self.append_batch(dataset_uri, ds, batch).await
669 }
670
671 pub async fn overwrite_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
672 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
673 let params = WriteParams {
674 mode: WriteMode::Overwrite,
675 enable_stable_row_ids: true,
676 data_storage_version: Some(LanceFileVersion::V2_2),
677 allow_external_blob_outside_bases: true,
678 ..Default::default()
679 };
680 Dataset::write(reader, dataset_uri, Some(params))
681 .await
682 .map_err(|e| OmniError::Lance(e.to_string()))
683 }
684
685 pub async fn merge_insert_batch(
686 &self,
687 dataset_uri: &str,
688 ds: Dataset,
689 batch: RecordBatch,
690 key_columns: Vec<String>,
691 when_matched: WhenMatched,
692 when_not_matched: WhenNotMatched,
693 ) -> Result<TableState> {
694 if batch.num_rows() == 0 {
695 return self.table_state(dataset_uri, &ds).await;
696 }
697
698 check_batch_unique_by_keys(&batch, &key_columns, "merge_insert_batch")?;
703
704 let ds = Arc::new(ds);
709 let mut builder = MergeInsertBuilder::try_new(ds, key_columns)
710 .map_err(|e| OmniError::Lance(e.to_string()))?;
711 builder.when_matched(when_matched);
712 builder.when_not_matched(when_not_matched);
713 builder.source_dedupe_behavior(SourceDedupeBehavior::FirstSeen);
749 let job = builder
750 .try_build()
751 .map_err(|e| OmniError::Lance(e.to_string()))?;
752
753 let schema = batch.schema();
754 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
755 let (new_ds, _stats) = job
756 .execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
757 .await
758 .map_err(|e| OmniError::Lance(e.to_string()))?;
759 self.table_state(dataset_uri, &new_ds).await
760 }
761
762 pub async fn merge_insert_batches(
763 &self,
764 dataset_uri: &str,
765 ds: Dataset,
766 batches: Vec<RecordBatch>,
767 key_columns: Vec<String>,
768 when_matched: WhenMatched,
769 when_not_matched: WhenNotMatched,
770 ) -> Result<TableState> {
771 if batches.is_empty() {
772 return self.table_state(dataset_uri, &ds).await;
773 }
774 let batch = if batches.len() == 1 {
775 batches.into_iter().next().unwrap()
776 } else {
777 let schema = batches[0].schema();
778 concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
779 };
780 self.merge_insert_batch(
781 dataset_uri,
782 ds,
783 batch,
784 key_columns,
785 when_matched,
786 when_not_matched,
787 )
788 .await
789 }
790
791 pub async fn delete_where(
792 &self,
793 dataset_uri: &str,
794 ds: &mut Dataset,
795 filter: &str,
796 ) -> Result<DeleteState> {
797 let delete_result = ds
798 .delete(filter)
799 .await
800 .map_err(|e| OmniError::Lance(e.to_string()))?;
801 Ok(DeleteState {
802 version: delete_result.new_dataset.version().version,
803 row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
804 deleted_rows: delete_result.num_deleted_rows as usize,
805 version_metadata: self
806 .dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
807 })
808 }
809
810 pub async fn stage_append(
859 &self,
860 ds: &Dataset,
861 batch: RecordBatch,
862 prior_stages: &[StagedWrite],
863 ) -> Result<StagedWrite> {
864 if batch.num_rows() == 0 {
865 return Err(OmniError::manifest_internal(
866 "stage_append called with empty batch".to_string(),
867 ));
868 }
869 let params = WriteParams {
870 mode: WriteMode::Append,
871 allow_external_blob_outside_bases: true,
872 ..Default::default()
873 };
874 let transaction = InsertBuilder::new(Arc::new(ds.clone()))
875 .with_params(¶ms)
876 .execute_uncommitted(vec![batch])
877 .await
878 .map_err(|e| OmniError::Lance(e.to_string()))?;
879 let mut new_fragments = match &transaction.operation {
880 Operation::Append { fragments } => fragments.clone(),
881 Operation::Overwrite { fragments, .. } => fragments.clone(),
882 other => {
883 return Err(OmniError::manifest_internal(format!(
884 "stage_append: unexpected Lance operation {:?}",
885 std::mem::discriminant(other)
886 )));
887 }
888 };
889 let next_id_base = ds.manifest.max_fragment_id.unwrap_or(0) as u64
903 + 1
904 + prior_stages_fragment_count(prior_stages);
905 assign_fragment_ids(&mut new_fragments, next_id_base);
906 if ds.manifest.uses_stable_row_ids() {
907 let prior_rows = prior_stages_row_count(prior_stages)?;
908 let start_row_id = ds.manifest.next_row_id + prior_rows;
909 assign_row_id_meta(&mut new_fragments, start_row_id)?;
910 }
911 Ok(StagedWrite {
912 transaction,
913 new_fragments,
914 removed_fragment_ids: Vec::new(),
916 })
917 }
918
919 pub async fn stage_merge_insert(
948 &self,
949 ds: Dataset,
950 batch: RecordBatch,
951 key_columns: Vec<String>,
952 when_matched: WhenMatched,
953 when_not_matched: WhenNotMatched,
954 ) -> Result<StagedWrite> {
955 if batch.num_rows() == 0 {
956 return Err(OmniError::manifest_internal(
957 "stage_merge_insert called with empty batch".to_string(),
958 ));
959 }
960
961 check_batch_unique_by_keys(&batch, &key_columns, "stage_merge_insert")?;
967
968 let ds = Arc::new(ds);
969 let mut builder = MergeInsertBuilder::try_new(ds, key_columns)
970 .map_err(|e| OmniError::Lance(e.to_string()))?;
971 builder.when_matched(when_matched);
972 builder.when_not_matched(when_not_matched);
973 builder.source_dedupe_behavior(SourceDedupeBehavior::FirstSeen);
979 let job = builder
980 .try_build()
981 .map_err(|e| OmniError::Lance(e.to_string()))?;
982 let schema = batch.schema();
983 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
984 let stream = lance_datafusion::utils::reader_to_stream(Box::new(reader));
985 let uncommitted = job
986 .execute_uncommitted(stream)
987 .await
988 .map_err(|e| OmniError::Lance(e.to_string()))?;
989 let (new_fragments, removed_fragment_ids) = match &uncommitted.transaction.operation {
1000 Operation::Update {
1001 new_fragments,
1002 updated_fragments,
1003 removed_fragment_ids,
1004 ..
1005 } => {
1006 let mut all = updated_fragments.clone();
1007 all.extend(new_fragments.iter().cloned());
1008 (all, removed_fragment_ids.clone())
1009 }
1010 Operation::Append { fragments } => (fragments.clone(), Vec::new()),
1011 other => {
1012 return Err(OmniError::manifest_internal(format!(
1013 "stage_merge_insert: unexpected Lance operation {:?}",
1014 std::mem::discriminant(other)
1015 )));
1016 }
1017 };
1018 Ok(StagedWrite {
1019 transaction: uncommitted.transaction,
1020 new_fragments,
1021 removed_fragment_ids,
1022 })
1023 }
1024
1025 pub async fn commit_staged(
1030 &self,
1031 ds: Arc<Dataset>,
1032 transaction: Transaction,
1033 ) -> Result<Dataset> {
1034 CommitBuilder::new(ds)
1035 .execute(transaction)
1036 .await
1037 .map_err(|e| OmniError::Lance(e.to_string()))
1038 }
1039
1040 pub async fn stage_overwrite(&self, ds: &Dataset, batch: RecordBatch) -> Result<StagedWrite> {
1053 if batch.num_rows() == 0 {
1054 return Err(OmniError::manifest_internal(
1055 "stage_overwrite called with empty batch".to_string(),
1056 ));
1057 }
1058 let params = WriteParams {
1069 mode: WriteMode::Overwrite,
1070 enable_stable_row_ids: true,
1071 allow_external_blob_outside_bases: true,
1072 ..Default::default()
1073 };
1074 let transaction = InsertBuilder::new(Arc::new(ds.clone()))
1075 .with_params(¶ms)
1076 .execute_uncommitted(vec![batch])
1077 .await
1078 .map_err(|e| OmniError::Lance(e.to_string()))?;
1079 let mut new_fragments = match &transaction.operation {
1080 Operation::Overwrite { fragments, .. } => fragments.clone(),
1081 other => {
1082 return Err(OmniError::manifest_internal(format!(
1083 "stage_overwrite: unexpected Lance operation {:?}",
1084 std::mem::discriminant(other)
1085 )));
1086 }
1087 };
1088 assign_fragment_ids(&mut new_fragments, 1);
1101 if ds.manifest.uses_stable_row_ids() {
1102 assign_row_id_meta(&mut new_fragments, 0)?;
1103 }
1104 let removed_fragment_ids: Vec<u64> = ds.manifest.fragments.iter().map(|f| f.id).collect();
1109 Ok(StagedWrite {
1110 transaction,
1111 new_fragments,
1112 removed_fragment_ids,
1113 })
1114 }
1115
1116 pub async fn stage_create_btree_index(
1134 &self,
1135 ds: &Dataset,
1136 columns: &[&str],
1137 ) -> Result<StagedWrite> {
1138 let params = ScalarIndexParams::default();
1139 let mut ds_clone = ds.clone();
1140 let new_idx = ds_clone
1141 .create_index_builder(columns, IndexType::BTree, ¶ms)
1142 .replace(true)
1143 .execute_uncommitted()
1144 .await
1145 .map_err(|e| OmniError::Lance(format!("stage_create_btree_index: {}", e)))?;
1146 let removed_indices: Vec<IndexMetadata> = ds
1147 .load_indices()
1148 .await
1149 .map_err(|e| OmniError::Lance(e.to_string()))?
1150 .iter()
1151 .filter(|idx| idx.name == new_idx.name)
1152 .cloned()
1153 .collect();
1154 let transaction = TransactionBuilder::new(
1155 new_idx.dataset_version,
1156 Operation::CreateIndex {
1157 new_indices: vec![new_idx],
1158 removed_indices,
1159 },
1160 )
1161 .build();
1162 Ok(StagedWrite {
1163 transaction,
1164 new_fragments: Vec::new(),
1165 removed_fragment_ids: Vec::new(),
1166 })
1167 }
1168
1169 pub async fn stage_create_inverted_index(
1173 &self,
1174 ds: &Dataset,
1175 column: &str,
1176 ) -> Result<StagedWrite> {
1177 let params = InvertedIndexParams::default();
1178 let mut ds_clone = ds.clone();
1179 let new_idx = ds_clone
1180 .create_index_builder(&[column], IndexType::Inverted, ¶ms)
1181 .replace(true)
1182 .execute_uncommitted()
1183 .await
1184 .map_err(|e| OmniError::Lance(format!("stage_create_inverted_index: {}", e)))?;
1185 let removed_indices: Vec<IndexMetadata> = ds
1186 .load_indices()
1187 .await
1188 .map_err(|e| OmniError::Lance(e.to_string()))?
1189 .iter()
1190 .filter(|idx| idx.name == new_idx.name)
1191 .cloned()
1192 .collect();
1193 let transaction = TransactionBuilder::new(
1194 new_idx.dataset_version,
1195 Operation::CreateIndex {
1196 new_indices: vec![new_idx],
1197 removed_indices,
1198 },
1199 )
1200 .build();
1201 Ok(StagedWrite {
1202 transaction,
1203 new_fragments: Vec::new(),
1204 removed_fragment_ids: Vec::new(),
1205 })
1206 }
1207
1208 pub async fn scan_with_staged(
1237 &self,
1238 ds: &Dataset,
1239 staged: &[StagedWrite],
1240 projection: Option<&[&str]>,
1241 filter: Option<&str>,
1242 ) -> Result<Vec<RecordBatch>> {
1243 if staged.is_empty() {
1244 return self.scan(ds, projection, filter, None).await;
1245 }
1246 let mut scanner = ds.scan();
1247 if let Some(cols) = projection {
1248 let owned: Vec<String> = cols.iter().map(|s| s.to_string()).collect();
1249 scanner
1250 .project(&owned)
1251 .map_err(|e| OmniError::Lance(e.to_string()))?;
1252 }
1253 if let Some(f) = filter {
1254 scanner
1255 .filter(f)
1256 .map_err(|e| OmniError::Lance(e.to_string()))?;
1257 }
1258 scanner.with_fragments(combine_committed_with_staged(ds, staged));
1259 let stream = scanner
1260 .try_into_stream()
1261 .await
1262 .map_err(|e| OmniError::Lance(e.to_string()))?;
1263 stream
1264 .try_collect()
1265 .await
1266 .map_err(|e| OmniError::Lance(e.to_string()))
1267 }
1268
1269 pub async fn scan_with_pending(
1308 &self,
1309 committed_ds: &Dataset,
1310 pending_batches: &[RecordBatch],
1311 pending_schema: Option<SchemaRef>,
1312 projection: Option<&[&str]>,
1313 filter: Option<&str>,
1314 key_column: Option<&str>,
1315 ) -> Result<Vec<RecordBatch>> {
1316 if let (Some(key_col), Some(cols)) = (key_column, projection) {
1325 if !cols.iter().any(|c| *c == key_col) {
1326 return Err(OmniError::Lance(format!(
1327 "scan_with_pending: key_column '{}' must appear in projection \
1328 when merge-shadow semantics are requested (got projection = {:?})",
1329 key_col, cols
1330 )));
1331 }
1332 }
1333
1334 let committed = self.scan(committed_ds, projection, filter, None).await?;
1335 if pending_batches.is_empty() {
1336 return Ok(committed);
1337 }
1338
1339 let committed = match key_column {
1345 Some(key_col) => {
1346 let pending_keys = collect_string_column_values(pending_batches, key_col)?;
1347 if pending_keys.is_empty() {
1348 committed
1349 } else {
1350 filter_out_rows_where_string_in(committed, key_col, &pending_keys)?
1351 }
1352 }
1353 None => committed,
1354 };
1355
1356 let pending =
1357 scan_pending_batches(pending_batches, pending_schema, projection, filter).await?;
1358
1359 let mut out = committed;
1360 out.extend(pending);
1361 Ok(out)
1362 }
1363
1364 pub async fn count_rows_with_staged(
1369 &self,
1370 ds: &Dataset,
1371 staged: &[StagedWrite],
1372 filter: Option<String>,
1373 ) -> Result<usize> {
1374 if staged.is_empty() {
1375 return self.count_rows(ds, filter).await;
1376 }
1377 let mut scanner = ds.scan();
1378 if let Some(f) = filter {
1379 scanner
1380 .filter(&f)
1381 .map_err(|e| OmniError::Lance(e.to_string()))?;
1382 }
1383 scanner.with_fragments(combine_committed_with_staged(ds, staged));
1384 let count = scanner
1385 .count_rows()
1386 .await
1387 .map_err(|e| OmniError::Lance(e.to_string()))?;
1388 Ok(count as usize)
1389 }
1390
1391 async fn user_indices_for_column(
1392 &self,
1393 ds: &Dataset,
1394 column: &str,
1395 ) -> Result<Vec<IndexMetadata>> {
1396 let field_id = ds
1397 .schema()
1398 .field(column)
1399 .map(|field| field.id)
1400 .ok_or_else(|| {
1401 OmniError::manifest_internal(format!(
1402 "dataset is missing expected index column '{}'",
1403 column
1404 ))
1405 })?;
1406 let indices = ds
1407 .load_indices()
1408 .await
1409 .map_err(|e| OmniError::Lance(e.to_string()))?;
1410 Ok(indices
1411 .iter()
1412 .filter(|index| !is_system_index(index))
1413 .filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
1414 .cloned()
1415 .collect())
1416 }
1417
1418 pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
1419 let indices = self.user_indices_for_column(ds, column).await?;
1420 Ok(indices.iter().any(|index| {
1421 index
1422 .index_details
1423 .as_ref()
1424 .map(|details| details.type_url.ends_with("BTreeIndexDetails"))
1425 .unwrap_or(false)
1426 }))
1427 }
1428
1429 pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
1430 let indices = self.user_indices_for_column(ds, column).await?;
1431 Ok(indices.iter().any(|index| {
1432 index
1433 .index_details
1434 .as_ref()
1435 .map(|details| IndexDetails(details.clone()).supports_fts())
1436 .unwrap_or(false)
1437 }))
1438 }
1439
1440 pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
1441 let indices = self.user_indices_for_column(ds, column).await?;
1442 Ok(indices.iter().any(|index| {
1443 index
1444 .index_details
1445 .as_ref()
1446 .map(|details| IndexDetails(details.clone()).is_vector())
1447 .unwrap_or(false)
1448 }))
1449 }
1450
1451 pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
1452 let params = ScalarIndexParams::default();
1453 ds.create_index_builder(columns, IndexType::BTree, ¶ms)
1454 .replace(true)
1455 .await
1456 .map(|_| ())
1457 .map_err(|e| OmniError::Lance(e.to_string()))
1458 }
1459
1460 pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
1461 let params = InvertedIndexParams::default();
1462 ds.create_index_builder(&[column], IndexType::Inverted, ¶ms)
1463 .replace(true)
1464 .await
1465 .map(|_| ())
1466 .map_err(|e| OmniError::Lance(e.to_string()))
1467 }
1468
1469 pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
1470 let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
1471 ds.create_index_builder(&[column], IndexType::Vector, ¶ms)
1472 .replace(true)
1473 .await
1474 .map(|_| ())
1475 .map_err(|e| OmniError::Lance(e.to_string()))
1476 }
1477
1478 pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
1479 let batch = RecordBatch::new_empty(schema.clone());
1480 Self::write_dataset(dataset_uri, batch).await
1481 }
1482
1483 pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
1484 let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
1485 .await?
1486 .try_collect::<Vec<RecordBatch>>()
1487 .await
1488 .map_err(|e| OmniError::Lance(e.to_string()))?;
1489 Ok(batches.iter().find_map(|batch| {
1490 batch
1491 .column_by_name("_rowid")
1492 .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
1493 .and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
1494 }))
1495 }
1496
1497 pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
1498 let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
1499 let params = WriteParams {
1500 mode: WriteMode::Create,
1501 enable_stable_row_ids: true,
1502 data_storage_version: Some(LanceFileVersion::V2_2),
1503 allow_external_blob_outside_bases: true,
1504 ..Default::default()
1505 };
1506 Dataset::write(reader, dataset_uri, Some(params))
1507 .await
1508 .map_err(|e| OmniError::Lance(e.to_string()))
1509 }
1510}
1511
1512fn prior_stages_fragment_count(prior_stages: &[StagedWrite]) -> u64 {
1546 prior_stages
1547 .iter()
1548 .map(|s| s.new_fragments.len() as u64)
1549 .sum()
1550}
1551
1552fn assign_fragment_ids(fragments: &mut [Fragment], start_id: u64) {
1560 for (i, fragment) in fragments.iter_mut().enumerate() {
1561 if fragment.id == 0 {
1562 fragment.id = start_id + i as u64;
1563 }
1564 }
1565}
1566
1567fn prior_stages_row_count(prior_stages: &[StagedWrite]) -> Result<u64> {
1568 let mut total: u64 = 0;
1569 for stage in prior_stages {
1570 for fragment in &stage.new_fragments {
1571 let physical_rows = fragment.physical_rows.ok_or_else(|| {
1572 OmniError::manifest_internal(
1573 "prior_stages_row_count: fragment is missing physical_rows".to_string(),
1574 )
1575 })? as u64;
1576 total += physical_rows;
1577 }
1578 }
1579 Ok(total)
1580}
1581
1582fn assign_row_id_meta(fragments: &mut [Fragment], start_row_id: u64) -> Result<()> {
1592 let mut next_row_id = start_row_id;
1593 for fragment in fragments {
1594 if fragment.row_id_meta.is_some() {
1595 continue;
1596 }
1597 let physical_rows = fragment.physical_rows.ok_or_else(|| {
1598 OmniError::manifest_internal(
1599 "stage_append: fragment is missing physical_rows".to_string(),
1600 )
1601 })? as u64;
1602 let row_ids = next_row_id..(next_row_id + physical_rows);
1603 let sequence = RowIdSequence::from(row_ids);
1604 let serialized = write_row_ids(&sequence);
1605 fragment.row_id_meta = Some(RowIdMeta::Inline(serialized));
1606 next_row_id += physical_rows;
1607 }
1608 Ok(())
1609}
1610
1611fn collect_string_column_values(
1616 batches: &[RecordBatch],
1617 column: &str,
1618) -> Result<std::collections::HashSet<String>> {
1619 use arrow_array::{Array, StringArray};
1620 let mut out = std::collections::HashSet::new();
1621 for batch in batches {
1622 let Some(col) = batch.column_by_name(column) else {
1623 return Err(OmniError::Lance(format!(
1624 "scan_with_pending: pending batch missing key column '{}'",
1625 column
1626 )));
1627 };
1628 let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
1629 OmniError::Lance(format!(
1630 "scan_with_pending: key column '{}' is not Utf8",
1631 column
1632 ))
1633 })?;
1634 for i in 0..arr.len() {
1635 if arr.is_valid(i) {
1636 out.insert(arr.value(i).to_string());
1637 }
1638 }
1639 }
1640 Ok(out)
1641}
1642
1643fn filter_out_rows_where_string_in(
1652 batches: Vec<RecordBatch>,
1653 column: &str,
1654 excluded: &std::collections::HashSet<String>,
1655) -> Result<Vec<RecordBatch>> {
1656 use arrow_array::{Array, BooleanArray, StringArray};
1657 let mut out = Vec::with_capacity(batches.len());
1658 for batch in batches {
1659 if batch.num_rows() == 0 {
1660 out.push(batch);
1661 continue;
1662 }
1663 let col = batch.column_by_name(column).ok_or_else(|| {
1664 OmniError::manifest_internal(format!(
1665 "scan_with_pending: committed batch missing key column '{}' \
1666 (the up-front projection check should have rejected this)",
1667 column
1668 ))
1669 })?;
1670 let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
1671 OmniError::Lance(format!(
1672 "scan_with_pending: committed column '{}' is not Utf8",
1673 column
1674 ))
1675 })?;
1676 let mask: BooleanArray = (0..arr.len())
1677 .map(|i| {
1678 if arr.is_valid(i) {
1679 Some(!excluded.contains(arr.value(i)))
1680 } else {
1681 Some(true)
1682 }
1683 })
1684 .collect();
1685 let filtered = arrow_select::filter::filter_record_batch(&batch, &mask)
1686 .map_err(|e| OmniError::Lance(e.to_string()))?;
1687 out.push(filtered);
1688 }
1689 Ok(out)
1690}
1691
1692async fn scan_pending_batches(
1708 pending_batches: &[RecordBatch],
1709 pending_schema: Option<SchemaRef>,
1710 projection: Option<&[&str]>,
1711 filter: Option<&str>,
1712) -> Result<Vec<RecordBatch>> {
1713 let schema = pending_schema.unwrap_or_else(|| pending_batches[0].schema());
1714 let ctx = datafusion::execution::context::SessionContext::new();
1715 let mem = datafusion::datasource::MemTable::try_new(schema, vec![pending_batches.to_vec()])
1716 .map_err(|e| OmniError::Lance(e.to_string()))?;
1717 ctx.register_table("pending", Arc::new(mem))
1718 .map_err(|e| OmniError::Lance(e.to_string()))?;
1719
1720 let proj = projection
1721 .map(|cols| {
1722 cols.iter()
1723 .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
1724 .collect::<Vec<_>>()
1725 .join(", ")
1726 })
1727 .unwrap_or_else(|| "*".to_string());
1728 let where_clause = filter.map(|f| format!("WHERE {f}")).unwrap_or_default();
1729 let sql = format!("SELECT {proj} FROM pending {where_clause}");
1730 let df = ctx
1731 .sql(&sql)
1732 .await
1733 .map_err(|e| OmniError::Lance(e.to_string()))?;
1734 df.collect()
1735 .await
1736 .map_err(|e| OmniError::Lance(e.to_string()))
1737}
1738
1739fn combine_committed_with_staged(ds: &Dataset, staged: &[StagedWrite]) -> Vec<Fragment> {
1740 let removed: std::collections::HashSet<u64> = staged
1741 .iter()
1742 .flat_map(|w| w.removed_fragment_ids.iter().copied())
1743 .collect();
1744 let mut combined: Vec<Fragment> = ds
1745 .manifest
1746 .fragments
1747 .iter()
1748 .filter(|f| !removed.contains(&f.id))
1749 .cloned()
1750 .collect();
1751 for write in staged {
1752 combined.extend(write.new_fragments.iter().cloned());
1753 }
1754 combined
1755}
1756
1757fn check_batch_unique_by_keys(
1769 batch: &RecordBatch,
1770 key_columns: &[String],
1771 context: &'static str,
1772) -> Result<()> {
1773 if key_columns.len() != 1 {
1774 return Err(OmniError::manifest_internal(format!(
1775 "{}: check_batch_unique_by_keys currently supports single-column keys only, got {:?}",
1776 context, key_columns
1777 )));
1778 }
1779 let key_col_name = &key_columns[0];
1780 let column = batch.column_by_name(key_col_name).ok_or_else(|| {
1781 OmniError::manifest_internal(format!(
1782 "{}: source batch missing key column '{}'",
1783 context, key_col_name
1784 ))
1785 })?;
1786 let strs = column
1787 .as_any()
1788 .downcast_ref::<StringArray>()
1789 .ok_or_else(|| {
1790 OmniError::manifest_internal(format!(
1791 "{}: key column '{}' is not a StringArray (got {:?})",
1792 context,
1793 key_col_name,
1794 column.data_type()
1795 ))
1796 })?;
1797
1798 let mut seen: std::collections::HashSet<&str> =
1799 std::collections::HashSet::with_capacity(batch.num_rows());
1800 for i in 0..strs.len() {
1801 if !strs.is_valid(i) {
1802 continue;
1803 }
1804 let v = strs.value(i);
1805 if !seen.insert(v) {
1806 return Err(OmniError::manifest(format!(
1807 "{}: duplicate source row for key '{}' (column '{}'); \
1808 callers must hand in a batch unique by `key_columns` \
1809 — see MR-957",
1810 context, v, key_col_name
1811 )));
1812 }
1813 }
1814 Ok(())
1815}
1816
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    /// Builds a one-column batch whose non-nullable Utf8 `id` column holds `ids`.
    fn batch_with_ids(ids: &[&str]) -> RecordBatch {
        let id_field = Field::new("id", DataType::Utf8, false);
        let values: ArrayRef = Arc::new(StringArray::from(ids.to_vec()));
        RecordBatch::try_new(Arc::new(Schema::new(vec![id_field])), vec![values]).unwrap()
    }

    /// The single-column key list shared by the tests below.
    fn id_key() -> Vec<String> {
        vec!["id".to_string()]
    }

    #[test]
    fn check_batch_unique_by_keys_passes_when_all_unique() {
        let unique = batch_with_ids(&["a", "b", "c"]);
        let outcome = check_batch_unique_by_keys(&unique, &id_key(), "test");
        outcome.unwrap();
    }

    #[test]
    fn check_batch_unique_by_keys_errors_on_duplicate_id() {
        let with_duplicate = batch_with_ids(&["a", "b", "a"]);
        let msg = check_batch_unique_by_keys(&with_duplicate, &id_key(), "test")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("duplicate source row for key 'a'"),
            "unexpected error: {msg}"
        );
        assert!(
            msg.contains("MR-957"),
            "error should reference MR-957: {msg}"
        );
    }

    #[test]
    fn check_batch_unique_by_keys_rejects_multi_column_keys() {
        let single_row = batch_with_ids(&["a"]);
        let two_keys = vec!["id".to_string(), "other".to_string()];
        let err = check_batch_unique_by_keys(&single_row, &two_keys, "test").unwrap_err();
        assert!(err.to_string().contains("single-column keys only"));
    }
}