use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use arrow_select::concat::concat_batches;
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::transaction::{Operation, Transaction, TransactionBuilder};
use lance::dataset::{
    CommitBuilder, InsertBuilder, MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode,
    WriteParams,
};
use lance::datatypes::BlobHandling;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
use lance_index::{DatasetIndexExt, IndexType, is_system_index};
use lance_linalg::distance::MetricType;
use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
use lance_table::rowids::{RowIdSequence, write_row_ids};
use std::sync::Arc;

use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
use crate::db::{Snapshot, SubTableEntry};
use crate::error::{OmniError, Result};

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableState {
    pub version: u64,
    pub row_count: u64,
    pub(crate) version_metadata: TableVersionMetadata,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeleteState {
    pub version: u64,
    pub row_count: u64,
    pub deleted_rows: usize,
    pub(crate) version_metadata: TableVersionMetadata,
}

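/// A write that has been executed against a dataset but not yet committed.
///
/// `transaction` is the uncommitted Lance transaction; `new_fragments` are the
/// fragments the write produced (with ids and row-id metadata already
/// assigned); `removed_fragment_ids` are committed fragments the write
/// replaces. Staged scans overlay these on the committed manifest, and
/// `commit_staged` makes the write durable.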
#[derive(Debug, Clone)]
pub struct StagedWrite {
    pub transaction: Transaction,
    pub new_fragments: Vec<Fragment>,
    pub removed_fragment_ids: Vec<u64>,
}

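/// Table-level open/scan/write helpers over Lance datasets rooted at a single
/// URI.
///
/// A minimal usage sketch (the URI is illustrative):
///
/// ```ignore
/// let store = TableStore::new("s3://bucket/db/");
/// assert_eq!(store.root_uri(), "s3://bucket/db");
/// assert_eq!(store.dataset_uri("users"), "s3://bucket/db/users");
/// ```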
#[derive(Debug, Clone)]
pub struct TableStore {
    root_uri: String,
}

impl TableStore {
    pub fn new(root_uri: &str) -> Self {
        Self {
            root_uri: root_uri.trim_end_matches('/').to_string(),
        }
    }

    pub fn root_uri(&self) -> &str {
        &self.root_uri
    }

    pub fn dataset_uri(&self, table_path: &str) -> String {
        format!("{}/{}", self.root_uri, table_path)
    }

    fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
        let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
        let table_path = dataset_uri
            .strip_prefix(&prefix)
            .map(|path| path.to_string())
            .ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "dataset uri '{}' is not under root '{}'",
                    dataset_uri, self.root_uri
                ))
            })?;
        Ok(table_path
            .split_once("/tree/")
            .map(|(path, _)| path.to_string())
            .unwrap_or(table_path))
    }

    fn dataset_version_metadata(
        &self,
        dataset_uri: &str,
        ds: &Dataset,
    ) -> Result<TableVersionMetadata> {
        let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
        TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
    }

    pub async fn open_snapshot_table(
        &self,
        snapshot: &Snapshot,
        table_key: &str,
    ) -> Result<Dataset> {
        snapshot.open(table_key).await
    }

    pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
        entry.open(&self.root_uri).await
    }

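    /// Opens the dataset at `dataset_uri`, checking out `branch` when one is
    /// given and it is not `"main"` (the default branch needs no checkout).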
    pub async fn open_dataset_head(
        &self,
        dataset_uri: &str,
        branch: Option<&str>,
    ) -> Result<Dataset> {
        let ds = Dataset::open(dataset_uri)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        match branch {
            Some(branch) if branch != "main" => ds
                .checkout_branch(branch)
                .await
                .map_err(|e| OmniError::Lance(e.to_string())),
            _ => Ok(ds),
        }
    }

    pub async fn open_dataset_head_for_write(
        &self,
        table_key: &str,
        dataset_uri: &str,
        branch: Option<&str>,
    ) -> Result<Dataset> {
        let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
        open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
    }

    pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
        let mut ds = Dataset::open(dataset_uri)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        ds.delete_branch(branch)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn open_dataset_at_state(
        &self,
        table_path: &str,
        branch: Option<&str>,
        version: u64,
    ) -> Result<Dataset> {
        let ds = self
            .open_dataset_head(&self.dataset_uri(table_path), branch)
            .await?;
        ds.checkout_version(version)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub fn ensure_expected_version(
        &self,
        ds: &Dataset,
        table_key: &str,
        expected_version: u64,
    ) -> Result<()> {
        let actual = ds.version().version;
        if actual != expected_version {
            return Err(OmniError::manifest_expected_version_mismatch(
                table_key,
                expected_version,
                actual,
            ));
        }
        Ok(())
    }

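    /// Reopens the table head for writing and fails fast if the head has
    /// moved past `expected_version` (optimistic concurrency check).
    ///
    /// Illustrative sketch (URI, key, and version are made up):
    ///
    /// ```ignore
    /// let ds = store
    ///     .reopen_for_mutation("s3://bucket/db/users", None, "users", 7)
    ///     .await?;
    /// ```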
    pub async fn reopen_for_mutation(
        &self,
        dataset_uri: &str,
        branch: Option<&str>,
        table_key: &str,
        expected_version: u64,
    ) -> Result<Dataset> {
        let ds = self
            .open_dataset_head_for_write(table_key, dataset_uri, branch)
            .await?;
        self.ensure_expected_version(&ds, table_key, expected_version)?;
        Ok(ds)
    }

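    /// Forks `target_branch` from `source_version` of the source branch.
    ///
    /// Branch creation is treated as idempotent: if `create_branch` fails but
    /// the target branch already exists at the expected version, the existing
    /// branch is returned instead of the error.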
    pub async fn fork_branch_from_state(
        &self,
        dataset_uri: &str,
        source_branch: Option<&str>,
        table_key: &str,
        source_version: u64,
        target_branch: &str,
    ) -> Result<Dataset> {
        let mut source_ds = self
            .open_dataset_head(dataset_uri, source_branch)
            .await?
            .checkout_version(source_version)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.ensure_expected_version(&source_ds, table_key, source_version)?;

        match source_ds
            .create_branch(target_branch, source_version, None)
            .await
        {
            Ok(_) => {}
            Err(create_err) => match self
                .open_dataset_head(dataset_uri, Some(target_branch))
                .await
            {
                Ok(ds) => {
                    self.ensure_expected_version(&ds, table_key, source_version)?;
                    return Ok(ds);
                }
                Err(_) => return Err(OmniError::Lance(create_err.to_string())),
            },
        }

        let ds = self
            .open_dataset_head(dataset_uri, Some(target_branch))
            .await?;
        self.ensure_expected_version(&ds, table_key, source_version)?;
        Ok(ds)
    }

    pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
        self.scan(ds, None, None, None).await
    }

    pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
        let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
        if !has_blob_columns {
            return self.scan_batches(ds).await;
        }

        let mut scanner = ds.scan();
        scanner.blob_handling(BlobHandling::AllBinary);
        scanner
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

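    /// Streams scan results with optional projection, filter, ordering, and
    /// row ids. `scan_stream_with` additionally lets the caller configure the
    /// [`Scanner`] before execution.
    ///
    /// Illustrative sketch (column and filter names are made up):
    ///
    /// ```ignore
    /// let stream = TableStore::scan_stream(
    ///     &ds,
    ///     Some(&["id", "name"]),
    ///     Some("name IS NOT NULL"),
    ///     None,
    ///     false,
    /// )
    /// .await?;
    /// ```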
    pub async fn scan_stream(
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
    ) -> Result<DatasetRecordBatchStream> {
        Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
    }

    pub async fn scan_stream_with<F>(
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
        configure: F,
    ) -> Result<DatasetRecordBatchStream>
    where
        F: FnOnce(&mut Scanner) -> Result<()>,
    {
        let mut scanner = ds.scan();
        if with_row_id {
            scanner.with_row_id();
        }
        if let Some(columns) = projection {
            scanner
                .project(columns)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        if let Some(filter_sql) = filter {
            scanner
                .filter(filter_sql)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        if let Some(ordering) = order_by {
            scanner
                .order_by(Some(ordering))
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        configure(&mut scanner)?;
        scanner
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn scan(
        &self,
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
    ) -> Result<Vec<RecordBatch>> {
        Self::scan_stream(ds, projection, filter, order_by, false)
            .await?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn scan_with<F>(
        &self,
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
        configure: F,
    ) -> Result<Vec<RecordBatch>>
    where
        F: FnOnce(&mut Scanner) -> Result<()>,
    {
        Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
            .await?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
        ds.count_rows(filter)
            .await
            .map(|count| count as usize)
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub fn dataset_version(&self, ds: &Dataset) -> u64 {
        ds.version().version
    }

    pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
        Ok(TableState {
            version: self.dataset_version(ds),
            row_count: self.count_rows(ds, None).await? as u64,
            version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
        })
    }

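    /// Appends one batch and returns the resulting [`TableState`]. An empty
    /// batch is a no-op that just reports the current state.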
    pub async fn append_batch(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        batch: RecordBatch,
    ) -> Result<TableState> {
        if batch.num_rows() == 0 {
            return self.table_state(dataset_uri, ds).await;
        }
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        let params = WriteParams {
            mode: WriteMode::Append,
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        ds.append(reader, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.table_state(dataset_uri, ds).await
    }

    pub async fn append_or_create_batch(
        dataset_uri: &str,
        dataset: Option<Dataset>,
        batch: RecordBatch,
    ) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        match dataset {
            Some(mut ds) => {
                let params = WriteParams {
                    mode: WriteMode::Append,
                    allow_external_blob_outside_bases: true,
                    ..Default::default()
                };
                ds.append(reader, Some(params))
                    .await
                    .map_err(|e| OmniError::Lance(e.to_string()))?;
                Ok(ds)
            }
            None => {
                let params = WriteParams {
                    mode: WriteMode::Create,
                    enable_stable_row_ids: true,
                    data_storage_version: Some(LanceFileVersion::V2_2),
                    allow_external_blob_outside_bases: true,
                    ..Default::default()
                };
                Dataset::write(reader, dataset_uri, Some(params))
                    .await
                    .map_err(|e| OmniError::Lance(e.to_string()))
            }
        }
    }

    pub async fn overwrite_batch(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        batch: RecordBatch,
    ) -> Result<TableState> {
        ds.truncate_table()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.append_batch(dataset_uri, ds, batch).await
    }

    pub async fn overwrite_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        let params = WriteParams {
            mode: WriteMode::Overwrite,
            enable_stable_row_ids: true,
            data_storage_version: Some(LanceFileVersion::V2_2),
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        Dataset::write(reader, dataset_uri, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

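    /// Merge-inserts (upserts) a batch keyed on `key_columns`, with
    /// matched/not-matched behavior supplied by the caller.
    ///
    /// Illustrative upsert sketch (assumes lance's `WhenMatched::UpdateAll`
    /// and `WhenNotMatched::InsertAll` variants):
    ///
    /// ```ignore
    /// let state = store
    ///     .merge_insert_batch(
    ///         uri,
    ///         ds,
    ///         batch,
    ///         vec!["id".to_string()],
    ///         WhenMatched::UpdateAll,
    ///         WhenNotMatched::InsertAll,
    ///     )
    ///     .await?;
    /// ```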
    pub async fn merge_insert_batch(
        &self,
        dataset_uri: &str,
        ds: Dataset,
        batch: RecordBatch,
        key_columns: Vec<String>,
        when_matched: WhenMatched,
        when_not_matched: WhenNotMatched,
    ) -> Result<TableState> {
        if batch.num_rows() == 0 {
            return self.table_state(dataset_uri, &ds).await;
        }

        let ds = Arc::new(ds);
        let job = MergeInsertBuilder::try_new(ds, key_columns)
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .when_matched(when_matched)
            .when_not_matched(when_not_matched)
            .try_build()
            .map_err(|e| OmniError::Lance(e.to_string()))?;

        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        let (new_ds, _stats) = job
            .execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.table_state(dataset_uri, &new_ds).await
    }

    pub async fn merge_insert_batches(
        &self,
        dataset_uri: &str,
        ds: Dataset,
        batches: Vec<RecordBatch>,
        key_columns: Vec<String>,
        when_matched: WhenMatched,
        when_not_matched: WhenNotMatched,
    ) -> Result<TableState> {
        if batches.is_empty() {
            return self.table_state(dataset_uri, &ds).await;
        }
        let batch = if batches.len() == 1 {
            batches.into_iter().next().unwrap()
        } else {
            let schema = batches[0].schema();
            concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
        };
        self.merge_insert_batch(
            dataset_uri,
            ds,
            batch,
            key_columns,
            when_matched,
            when_not_matched,
        )
        .await
    }

    pub async fn delete_where(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        filter: &str,
    ) -> Result<DeleteState> {
        let delete_result = ds
            .delete(filter)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(DeleteState {
            version: delete_result.new_dataset.version().version,
            row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
            deleted_rows: delete_result.num_deleted_rows as usize,
            version_metadata: self
                .dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
        })
    }

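    /// Stages an append without committing it.
    ///
    /// The write is executed via `execute_uncommitted`, and the resulting
    /// fragments get ids (continuing after committed fragments and earlier
    /// stages) and, when stable row ids are enabled, contiguous row-id
    /// metadata, so staged scans can treat them like committed fragments.
    ///
    /// Illustrative stage-then-commit sketch:
    ///
    /// ```ignore
    /// let staged = store.stage_append(&ds, batch, &[]).await?;
    /// // ... inspect with scan_with_staged / count_rows_with_staged ...
    /// let ds = store.commit_staged(Arc::new(ds), staged.transaction).await?;
    /// ```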
    pub async fn stage_append(
        &self,
        ds: &Dataset,
        batch: RecordBatch,
        prior_stages: &[StagedWrite],
    ) -> Result<StagedWrite> {
        if batch.num_rows() == 0 {
            return Err(OmniError::manifest_internal(
                "stage_append called with empty batch".to_string(),
            ));
        }
        let params = WriteParams {
            mode: WriteMode::Append,
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        let transaction = InsertBuilder::new(Arc::new(ds.clone()))
            .with_params(&params)
            .execute_uncommitted(vec![batch])
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let mut new_fragments = match &transaction.operation {
            Operation::Append { fragments } => fragments.clone(),
            Operation::Overwrite { fragments, .. } => fragments.clone(),
            other => {
                return Err(OmniError::manifest_internal(format!(
                    "stage_append: unexpected Lance operation {:?}",
                    std::mem::discriminant(other)
                )));
            }
        };
        // Fragment ids continue after both the committed fragments and any
        // fragments produced by earlier (still uncommitted) stages.
        let next_id_base = ds.manifest.max_fragment_id.unwrap_or(0) as u64
            + 1
            + prior_stages_fragment_count(prior_stages);
        assign_fragment_ids(&mut new_fragments, next_id_base);
        if ds.manifest.uses_stable_row_ids() {
            let prior_rows = prior_stages_row_count(prior_stages)?;
            let start_row_id = ds.manifest.next_row_id + prior_rows;
            assign_row_id_meta(&mut new_fragments, start_row_id)?;
        }
        Ok(StagedWrite {
            transaction,
            new_fragments,
            removed_fragment_ids: Vec::new(),
        })
    }

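    /// Stages a merge-insert without committing it, normalizing the resulting
    /// Lance operation into `new_fragments` (updated plus newly written
    /// fragments) and `removed_fragment_ids` for staged scans.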
    pub async fn stage_merge_insert(
        &self,
        ds: Dataset,
        batch: RecordBatch,
        key_columns: Vec<String>,
        when_matched: WhenMatched,
        when_not_matched: WhenNotMatched,
    ) -> Result<StagedWrite> {
        if batch.num_rows() == 0 {
            return Err(OmniError::manifest_internal(
                "stage_merge_insert called with empty batch".to_string(),
            ));
        }
        let ds = Arc::new(ds);
        let job = MergeInsertBuilder::try_new(ds, key_columns)
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .when_matched(when_matched)
            .when_not_matched(when_not_matched)
            .try_build()
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        let stream = lance_datafusion::utils::reader_to_stream(Box::new(reader));
        let uncommitted = job
            .execute_uncommitted(stream)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let (new_fragments, removed_fragment_ids) = match &uncommitted.transaction.operation {
            Operation::Update {
                new_fragments,
                updated_fragments,
                removed_fragment_ids,
                ..
            } => {
                let mut all = updated_fragments.clone();
                all.extend(new_fragments.iter().cloned());
                (all, removed_fragment_ids.clone())
            }
            Operation::Append { fragments } => (fragments.clone(), Vec::new()),
            other => {
                return Err(OmniError::manifest_internal(format!(
                    "stage_merge_insert: unexpected Lance operation {:?}",
                    std::mem::discriminant(other)
                )));
            }
        };
        Ok(StagedWrite {
            transaction: uncommitted.transaction,
            new_fragments,
            removed_fragment_ids,
        })
    }

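    /// Commits a previously staged transaction against the given dataset.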
    pub async fn commit_staged(
        &self,
        ds: Arc<Dataset>,
        transaction: Transaction,
    ) -> Result<Dataset> {
        CommitBuilder::new(ds)
            .execute(transaction)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

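    /// Stages a full-table overwrite without committing it: fragment ids
    /// restart at 1, stable row ids restart at 0, and every committed
    /// fragment is reported as removed.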
    pub async fn stage_overwrite(
        &self,
        ds: &Dataset,
        batch: RecordBatch,
    ) -> Result<StagedWrite> {
        if batch.num_rows() == 0 {
            return Err(OmniError::manifest_internal(
                "stage_overwrite called with empty batch".to_string(),
            ));
        }
        let params = WriteParams {
            mode: WriteMode::Overwrite,
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        let transaction = InsertBuilder::new(Arc::new(ds.clone()))
            .with_params(&params)
            .execute_uncommitted(vec![batch])
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let mut new_fragments = match &transaction.operation {
            Operation::Overwrite { fragments, .. } => fragments.clone(),
            other => {
                return Err(OmniError::manifest_internal(format!(
                    "stage_overwrite: unexpected Lance operation {:?}",
                    std::mem::discriminant(other)
                )));
            }
        };
        assign_fragment_ids(&mut new_fragments, 1);
        if ds.manifest.uses_stable_row_ids() {
            assign_row_id_meta(&mut new_fragments, 0)?;
        }
        let removed_fragment_ids: Vec<u64> =
            ds.manifest.fragments.iter().map(|f| f.id).collect();
        Ok(StagedWrite {
            transaction,
            new_fragments,
            removed_fragment_ids,
        })
    }

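    /// Stages a BTree index build without committing it. Any existing index
    /// with the same name is listed as removed so the commit replaces it.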
    pub async fn stage_create_btree_index(
        &self,
        ds: &Dataset,
        columns: &[&str],
    ) -> Result<StagedWrite> {
        let params = ScalarIndexParams::default();
        let mut ds_clone = ds.clone();
        let new_idx = ds_clone
            .create_index_builder(columns, IndexType::BTree, &params)
            .replace(true)
            .execute_uncommitted()
            .await
            .map_err(|e| OmniError::Lance(format!("stage_create_btree_index: {}", e)))?;
        let removed_indices: Vec<IndexMetadata> = ds
            .load_indices()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .iter()
            .filter(|idx| idx.name == new_idx.name)
            .cloned()
            .collect();
        let transaction = TransactionBuilder::new(
            new_idx.dataset_version,
            Operation::CreateIndex {
                new_indices: vec![new_idx],
                removed_indices,
            },
        )
        .build();
        Ok(StagedWrite {
            transaction,
            new_fragments: Vec::new(),
            removed_fragment_ids: Vec::new(),
        })
    }

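    /// Stages an inverted (full-text) index build; otherwise identical to
    /// [`Self::stage_create_btree_index`].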
    pub async fn stage_create_inverted_index(
        &self,
        ds: &Dataset,
        column: &str,
    ) -> Result<StagedWrite> {
        let params = InvertedIndexParams::default();
        let mut ds_clone = ds.clone();
        let new_idx = ds_clone
            .create_index_builder(&[column], IndexType::Inverted, &params)
            .replace(true)
            .execute_uncommitted()
            .await
            .map_err(|e| OmniError::Lance(format!("stage_create_inverted_index: {}", e)))?;
        let removed_indices: Vec<IndexMetadata> = ds
            .load_indices()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .iter()
            .filter(|idx| idx.name == new_idx.name)
            .cloned()
            .collect();
        let transaction = TransactionBuilder::new(
            new_idx.dataset_version,
            Operation::CreateIndex {
                new_indices: vec![new_idx],
                removed_indices,
            },
        )
        .build();
        Ok(StagedWrite {
            transaction,
            new_fragments: Vec::new(),
            removed_fragment_ids: Vec::new(),
        })
    }

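    /// Scans committed data plus staged writes by handing the scanner a
    /// combined fragment list (committed fragments minus removed ones, plus
    /// each stage's new fragments).
    ///
    /// Illustrative sketch:
    ///
    /// ```ignore
    /// let staged = vec![store.stage_append(&ds, batch, &[]).await?];
    /// let rows = store
    ///     .scan_with_staged(&ds, &staged, Some(&["id"]), Some("id > 10"))
    ///     .await?;
    /// ```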
    pub async fn scan_with_staged(
        &self,
        ds: &Dataset,
        staged: &[StagedWrite],
        projection: Option<&[&str]>,
        filter: Option<&str>,
    ) -> Result<Vec<RecordBatch>> {
        if staged.is_empty() {
            return self.scan(ds, projection, filter, None).await;
        }
        let mut scanner = ds.scan();
        if let Some(cols) = projection {
            let owned: Vec<String> = cols.iter().map(|s| s.to_string()).collect();
            scanner
                .project(&owned)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        if let Some(f) = filter {
            scanner
                .filter(f)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        scanner.with_fragments(combine_committed_with_staged(ds, staged));
        let stream = scanner
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        stream
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

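    /// Scans committed data merged with in-memory pending batches.
    ///
    /// When `key_column` is given, pending rows shadow committed rows with
    /// the same key (merge-shadow semantics), so the key column must be part
    /// of any projection; pending rows are appended after the surviving
    /// committed rows.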
    pub async fn scan_with_pending(
        &self,
        committed_ds: &Dataset,
        pending_batches: &[RecordBatch],
        pending_schema: Option<SchemaRef>,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        key_column: Option<&str>,
    ) -> Result<Vec<RecordBatch>> {
        if let (Some(key_col), Some(cols)) = (key_column, projection) {
            if !cols.iter().any(|c| *c == key_col) {
                return Err(OmniError::Lance(format!(
                    "scan_with_pending: key_column '{}' must appear in projection \
                     when merge-shadow semantics are requested (got projection = {:?})",
                    key_col, cols
                )));
            }
        }

        let committed = self.scan(committed_ds, projection, filter, None).await?;
        if pending_batches.is_empty() {
            return Ok(committed);
        }

        let committed = match key_column {
            Some(key_col) => {
                let pending_keys = collect_string_column_values(pending_batches, key_col)?;
                if pending_keys.is_empty() {
                    committed
                } else {
                    filter_out_rows_where_string_in(committed, key_col, &pending_keys)?
                }
            }
            None => committed,
        };

        let pending =
            scan_pending_batches(pending_batches, pending_schema, projection, filter).await?;

        let mut out = committed;
        out.extend(pending);
        Ok(out)
    }

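    /// Counts rows across committed data plus staged writes, using the same
    /// combined fragment list as [`Self::scan_with_staged`].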
    pub async fn count_rows_with_staged(
        &self,
        ds: &Dataset,
        staged: &[StagedWrite],
        filter: Option<String>,
    ) -> Result<usize> {
        if staged.is_empty() {
            return self.count_rows(ds, filter).await;
        }
        let mut scanner = ds.scan();
        if let Some(f) = filter {
            scanner
                .filter(&f)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        scanner.with_fragments(combine_committed_with_staged(ds, staged));
        let count = scanner
            .count_rows()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(count as usize)
    }

    async fn user_indices_for_column(
        &self,
        ds: &Dataset,
        column: &str,
    ) -> Result<Vec<IndexMetadata>> {
        let field_id = ds
            .schema()
            .field(column)
            .map(|field| field.id)
            .ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "dataset is missing expected index column '{}'",
                    column
                ))
            })?;
        let indices = ds
            .load_indices()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(indices
            .iter()
            .filter(|index| !is_system_index(index))
            .filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
            .cloned()
            .collect())
    }

    pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| details.type_url.ends_with("BTreeIndexDetails"))
                .unwrap_or(false)
        }))
    }

    pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| IndexDetails(details.clone()).supports_fts())
                .unwrap_or(false)
        }))
    }

    pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| IndexDetails(details.clone()).is_vector())
                .unwrap_or(false)
        }))
    }

    pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
        let params = ScalarIndexParams::default();
        ds.create_index_builder(columns, IndexType::BTree, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
        let params = InvertedIndexParams::default();
        ds.create_index_builder(&[column], IndexType::Inverted, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
        let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
        ds.create_index_builder(&[column], IndexType::Vector, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
        let batch = RecordBatch::new_empty(schema.clone());
        Self::write_dataset(dataset_uri, batch).await
    }

    pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
        let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
            .await?
            .try_collect::<Vec<RecordBatch>>()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(batches.iter().find_map(|batch| {
            batch
                .column_by_name("_rowid")
                .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
                .and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
        }))
    }

    pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        let params = WriteParams {
            mode: WriteMode::Create,
            enable_stable_row_ids: true,
            data_storage_version: Some(LanceFileVersion::V2_2),
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        Dataset::write(reader, dataset_uri, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }
}

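/// Total number of fragments produced by earlier uncommitted stages; used to
/// offset fragment ids for the next stage.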
fn prior_stages_fragment_count(prior_stages: &[StagedWrite]) -> u64 {
    prior_stages
        .iter()
        .map(|s| s.new_fragments.len() as u64)
        .sum()
}

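/// Assigns sequential ids starting at `start_id` to fragments whose id is
/// still 0 (not yet assigned); fragments already carrying a nonzero id are
/// left untouched.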
fn assign_fragment_ids(fragments: &mut [Fragment], start_id: u64) {
    for (i, fragment) in fragments.iter_mut().enumerate() {
        if fragment.id == 0 {
            fragment.id = start_id + i as u64;
        }
    }
}

fn prior_stages_row_count(prior_stages: &[StagedWrite]) -> Result<u64> {
    let mut total: u64 = 0;
    for stage in prior_stages {
        for fragment in &stage.new_fragments {
            let physical_rows = fragment.physical_rows.ok_or_else(|| {
                OmniError::manifest_internal(
                    "prior_stages_row_count: fragment is missing physical_rows".to_string(),
                )
            })? as u64;
            total += physical_rows;
        }
    }
    Ok(total)
}

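/// Assigns contiguous stable row ids starting at `start_row_id` to fragments
/// that lack row-id metadata. For example, two 10-row fragments starting at
/// row id 100 receive the ranges 100..110 and 110..120.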
fn assign_row_id_meta(fragments: &mut [Fragment], start_row_id: u64) -> Result<()> {
    let mut next_row_id = start_row_id;
    for fragment in fragments {
        if fragment.row_id_meta.is_some() {
            continue;
        }
        let physical_rows = fragment.physical_rows.ok_or_else(|| {
            OmniError::manifest_internal(
                "stage_append: fragment is missing physical_rows".to_string(),
            )
        })? as u64;
        let row_ids = next_row_id..(next_row_id + physical_rows);
        let sequence = RowIdSequence::from(row_ids);
        let serialized = write_row_ids(&sequence);
        fragment.row_id_meta = Some(RowIdMeta::Inline(serialized));
        next_row_id += physical_rows;
    }
    Ok(())
}

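/// Collects the non-null values of a Utf8 column across a set of batches,
/// erroring if the column is missing or not Utf8.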
fn collect_string_column_values(
    batches: &[RecordBatch],
    column: &str,
) -> Result<std::collections::HashSet<String>> {
    use arrow_array::{Array, StringArray};
    let mut out = std::collections::HashSet::new();
    for batch in batches {
        let Some(col) = batch.column_by_name(column) else {
            return Err(OmniError::Lance(format!(
                "scan_with_pending: pending batch missing key column '{}'",
                column
            )));
        };
        let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
            OmniError::Lance(format!(
                "scan_with_pending: key column '{}' is not Utf8",
                column
            ))
        })?;
        for i in 0..arr.len() {
            if arr.is_valid(i) {
                out.insert(arr.value(i).to_string());
            }
        }
    }
    Ok(out)
}

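/// Drops rows whose key-column value appears in `excluded`. Rows with a null
/// key are always kept, since a null key can never match a pending key.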
fn filter_out_rows_where_string_in(
    batches: Vec<RecordBatch>,
    column: &str,
    excluded: &std::collections::HashSet<String>,
) -> Result<Vec<RecordBatch>> {
    use arrow_array::{Array, BooleanArray, StringArray};
    let mut out = Vec::with_capacity(batches.len());
    for batch in batches {
        if batch.num_rows() == 0 {
            out.push(batch);
            continue;
        }
        let col = batch.column_by_name(column).ok_or_else(|| {
            OmniError::manifest_internal(format!(
                "scan_with_pending: committed batch missing key column '{}' \
                 (the up-front projection check should have rejected this)",
                column
            ))
        })?;
        let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
            OmniError::Lance(format!(
                "scan_with_pending: committed column '{}' is not Utf8",
                column
            ))
        })?;
        let mask: BooleanArray = (0..arr.len())
            .map(|i| {
                if arr.is_valid(i) {
                    Some(!excluded.contains(arr.value(i)))
                } else {
                    Some(true)
                }
            })
            .collect();
        let filtered = arrow_select::filter::filter_record_batch(&batch, &mask)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        out.push(filtered);
    }
    Ok(out)
}

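/// Applies the projection and filter to in-memory pending batches by
/// registering them as a DataFusion `MemTable` and running a generated
/// `SELECT`, so pending data goes through the same SQL semantics as the
/// committed scan.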
async fn scan_pending_batches(
    pending_batches: &[RecordBatch],
    pending_schema: Option<SchemaRef>,
    projection: Option<&[&str]>,
    filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
    let schema = pending_schema.unwrap_or_else(|| pending_batches[0].schema());
    let ctx = datafusion::execution::context::SessionContext::new();
    let mem = datafusion::datasource::MemTable::try_new(schema, vec![pending_batches.to_vec()])
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    ctx.register_table("pending", Arc::new(mem))
        .map_err(|e| OmniError::Lance(e.to_string()))?;

    // Quote each projected column (doubling embedded quotes) so identifiers
    // with special characters survive the round-trip through SQL.
    let proj = projection
        .map(|cols| {
            cols.iter()
                .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
                .collect::<Vec<_>>()
                .join(", ")
        })
        .unwrap_or_else(|| "*".to_string());
    let where_clause = filter.map(|f| format!("WHERE {f}")).unwrap_or_default();
    let sql = format!("SELECT {proj} FROM pending {where_clause}");
    let df = ctx
        .sql(&sql)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    df.collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))
}

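/// Builds the fragment list a staged scan should see: committed fragments
/// minus any a stage removed, followed by each stage's new fragments in
/// stage order.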
fn combine_committed_with_staged(ds: &Dataset, staged: &[StagedWrite]) -> Vec<Fragment> {
    let removed: std::collections::HashSet<u64> = staged
        .iter()
        .flat_map(|w| w.removed_fragment_ids.iter().copied())
        .collect();
    let mut combined: Vec<Fragment> = ds
        .manifest
        .fragments
        .iter()
        .filter(|f| !removed.contains(&f.id))
        .cloned()
        .collect();
    for write in staged {
        combined.extend(write.new_fragments.iter().cloned());
    }
    combined
}