// omnigraph/table_store.rs

use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use arrow_select::concat::concat_batches;
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::{MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode, WriteParams};
use lance::datatypes::BlobHandling;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
use lance_index::{DatasetIndexExt, IndexType, is_system_index};
use lance_linalg::distance::MetricType;
use lance_table::format::IndexMetadata;
use std::sync::Arc;

use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
use crate::db::{Snapshot, SubTableEntry};
use crate::error::{OmniError, Result};

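/// State of a table as observed after an operation: the Lance dataset
/// version, the row count at that version, and the manifest-level version
/// metadata recorded for it.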
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableState {
    pub version: u64,
    pub row_count: u64,
    pub(crate) version_metadata: TableVersionMetadata,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeleteState {
    pub version: u64,
    pub row_count: u64,
    pub deleted_rows: usize,
    pub(crate) version_metadata: TableVersionMetadata,
}

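/// Thin wrapper around Lance datasets stored under a single root URI. Table
/// paths are resolved relative to `root_uri`, and Lance errors are surfaced
/// as `OmniError::Lance`.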
#[derive(Debug, Clone)]
pub struct TableStore {
    root_uri: String,
}

impl TableStore {
    pub fn new(root_uri: &str) -> Self {
        Self {
            root_uri: root_uri.trim_end_matches('/').to_string(),
        }
    }

    pub fn root_uri(&self) -> &str {
        &self.root_uri
    }

    pub fn dataset_uri(&self, table_path: &str) -> String {
        format!("{}/{}", self.root_uri, table_path)
    }

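    /// Strips the store's root prefix (and any `/tree/<branch>` suffix) from a
    /// dataset URI to recover the manifest table path.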
    fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
        let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
        let table_path = dataset_uri
            .strip_prefix(&prefix)
            .map(|path| path.to_string())
            .ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "dataset uri '{}' is not under root '{}'",
                    dataset_uri, self.root_uri
                ))
            })?;
        Ok(table_path
            .split_once("/tree/")
            .map(|(path, _)| path.to_string())
            .unwrap_or(table_path))
    }

    fn dataset_version_metadata(
        &self,
        dataset_uri: &str,
        ds: &Dataset,
    ) -> Result<TableVersionMetadata> {
        let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
        TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
    }

    pub async fn open_snapshot_table(
        &self,
        snapshot: &Snapshot,
        table_key: &str,
    ) -> Result<Dataset> {
        snapshot.open(table_key).await
    }

    pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
        entry.open(&self.root_uri).await
    }

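    /// Opens the head of the dataset at `dataset_uri`, checking out `branch`
    /// when it is set to anything other than `"main"`.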
    pub async fn open_dataset_head(
        &self,
        dataset_uri: &str,
        branch: Option<&str>,
    ) -> Result<Dataset> {
        let ds = Dataset::open(dataset_uri)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        match branch {
            Some(branch) if branch != "main" => ds
                .checkout_branch(branch)
                .await
                .map_err(|e| OmniError::Lance(e.to_string())),
            _ => Ok(ds),
        }
    }

    pub async fn open_dataset_head_for_write(
        &self,
        table_key: &str,
        dataset_uri: &str,
        branch: Option<&str>,
    ) -> Result<Dataset> {
        let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
        open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
    }

    pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
        let mut ds = Dataset::open(dataset_uri)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        ds.delete_branch(branch)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn open_dataset_at_state(
        &self,
        table_path: &str,
        branch: Option<&str>,
        version: u64,
    ) -> Result<Dataset> {
        let ds = self
            .open_dataset_head(&self.dataset_uri(table_path), branch)
            .await?;
        ds.checkout_version(version)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

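    /// Optimistic-concurrency guard: returns a manifest conflict when the
    /// opened dataset is not at the version pinned by the snapshot.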
    pub fn ensure_expected_version(
        &self,
        ds: &Dataset,
        table_key: &str,
        expected_version: u64,
    ) -> Result<()> {
        if ds.version().version != expected_version {
            return Err(OmniError::manifest_conflict(format!(
                "version drift on {}: snapshot pinned v{} but dataset is at v{} — call sync_branch() and retry",
                table_key,
                expected_version,
                ds.version().version
            )));
        }
        Ok(())
    }

    pub async fn reopen_for_mutation(
        &self,
        dataset_uri: &str,
        branch: Option<&str>,
        table_key: &str,
        expected_version: u64,
    ) -> Result<Dataset> {
        let ds = self
            .open_dataset_head_for_write(table_key, dataset_uri, branch)
            .await?;
        self.ensure_expected_version(&ds, table_key, expected_version)?;
        Ok(ds)
    }

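    /// Forks `target_branch` from `source_version` of `source_branch`. If the
    /// branch already exists at the same version, the existing head is
    /// returned so retried forks stay idempotent; any other version is
    /// reported as version drift.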
    pub async fn fork_branch_from_state(
        &self,
        dataset_uri: &str,
        source_branch: Option<&str>,
        table_key: &str,
        source_version: u64,
        target_branch: &str,
    ) -> Result<Dataset> {
        let mut source_ds = self
            .open_dataset_head(dataset_uri, source_branch)
            .await?
            .checkout_version(source_version)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.ensure_expected_version(&source_ds, table_key, source_version)?;

        match source_ds
            .create_branch(target_branch, source_version, None)
            .await
        {
            Ok(_) => {}
            Err(create_err) => match self
                .open_dataset_head(dataset_uri, Some(target_branch))
                .await
            {
                Ok(ds) => {
                    self.ensure_expected_version(&ds, table_key, source_version)?;
                    return Ok(ds);
                }
                Err(_) => return Err(OmniError::Lance(create_err.to_string())),
            },
        }

        let ds = self
            .open_dataset_head(dataset_uri, Some(target_branch))
            .await?;
        self.ensure_expected_version(&ds, table_key, source_version)?;
        Ok(ds)
    }

    pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
        self.scan(ds, None, None, None).await
    }

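    /// Full scan used when rewriting a table: blob columns are scanned with
    /// `BlobHandling::AllBinary`; tables without blob columns fall back to the
    /// regular scan path.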
    pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
        let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
        if !has_blob_columns {
            return self.scan_batches(ds).await;
        }

        let mut scanner = ds.scan();
        scanner.blob_handling(BlobHandling::AllBinary);
        scanner
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn scan_stream(
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
    ) -> Result<DatasetRecordBatchStream> {
        Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
    }

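    /// Builds a scanner with optional projection, filter, ordering, and row
    /// ids, then lets `configure` apply extra settings before the stream is
    /// created.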
    pub async fn scan_stream_with<F>(
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
        configure: F,
    ) -> Result<DatasetRecordBatchStream>
    where
        F: FnOnce(&mut Scanner) -> Result<()>,
    {
        let mut scanner = ds.scan();
        if with_row_id {
            scanner.with_row_id();
        }
        if let Some(columns) = projection {
            scanner
                .project(columns)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        if let Some(filter_sql) = filter {
            scanner
                .filter(filter_sql)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        if let Some(ordering) = order_by {
            scanner
                .order_by(Some(ordering))
                .map_err(|e| OmniError::Lance(e.to_string()))?;
        }
        configure(&mut scanner)?;
        scanner
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn scan(
        &self,
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
    ) -> Result<Vec<RecordBatch>> {
        Self::scan_stream(ds, projection, filter, order_by, false)
            .await?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn scan_with<F>(
        &self,
        ds: &Dataset,
        projection: Option<&[&str]>,
        filter: Option<&str>,
        order_by: Option<Vec<ColumnOrdering>>,
        with_row_id: bool,
        configure: F,
    ) -> Result<Vec<RecordBatch>>
    where
        F: FnOnce(&mut Scanner) -> Result<()>,
    {
        Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
            .await?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
        ds.count_rows(filter)
            .await
            .map(|count| count as usize)
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub fn dataset_version(&self, ds: &Dataset) -> u64 {
        ds.version().version
    }

    pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
        Ok(TableState {
            version: self.dataset_version(ds),
            row_count: self.count_rows(ds, None).await? as u64,
            version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
        })
    }

    pub async fn append_batch(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        batch: RecordBatch,
    ) -> Result<TableState> {
        if batch.num_rows() == 0 {
            return self.table_state(dataset_uri, ds).await;
        }
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        let params = WriteParams {
            mode: WriteMode::Append,
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        ds.append(reader, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.table_state(dataset_uri, ds).await
    }

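    /// Appends `batch` to an existing dataset, or creates a new dataset at
    /// `dataset_uri` (stable row ids, Lance file version V2_2) when none is
    /// given.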
    pub async fn append_or_create_batch(
        dataset_uri: &str,
        dataset: Option<Dataset>,
        batch: RecordBatch,
    ) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        match dataset {
            Some(mut ds) => {
                let params = WriteParams {
                    mode: WriteMode::Append,
                    allow_external_blob_outside_bases: true,
                    ..Default::default()
                };
                ds.append(reader, Some(params))
                    .await
                    .map_err(|e| OmniError::Lance(e.to_string()))?;
                Ok(ds)
            }
            None => {
                let params = WriteParams {
                    mode: WriteMode::Create,
                    enable_stable_row_ids: true,
                    data_storage_version: Some(LanceFileVersion::V2_2),
                    allow_external_blob_outside_bases: true,
                    ..Default::default()
                };
                Dataset::write(reader, dataset_uri, Some(params))
                    .await
                    .map_err(|e| OmniError::Lance(e.to_string()))
            }
        }
    }

    pub async fn overwrite_batch(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        batch: RecordBatch,
    ) -> Result<TableState> {
        ds.truncate_table()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.append_batch(dataset_uri, ds, batch).await
    }

    pub async fn overwrite_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        let params = WriteParams {
            mode: WriteMode::Overwrite,
            enable_stable_row_ids: true,
            data_storage_version: Some(LanceFileVersion::V2_2),
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        Dataset::write(reader, dataset_uri, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

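    /// Upserts `batch` into `ds` keyed on `key_columns` via Lance merge-insert
    /// and returns the state of the resulting dataset. An empty batch is a
    /// no-op that just reports the current state.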
    pub async fn merge_insert_batch(
        &self,
        dataset_uri: &str,
        ds: Dataset,
        batch: RecordBatch,
        key_columns: Vec<String>,
        when_matched: WhenMatched,
        when_not_matched: WhenNotMatched,
    ) -> Result<TableState> {
        if batch.num_rows() == 0 {
            return self.table_state(dataset_uri, &ds).await;
        }

        // TODO(lance-upstream): MergeInsertBuilder does not accept WriteParams,
        // so allow_external_blob_outside_bases cannot be set here. External URI
        // blobs via merge_insert (LoadMode::Merge, mutations) are unsupported
        // until Lance exposes WriteParams on MergeInsertBuilder.
        let ds = Arc::new(ds);
        let job = MergeInsertBuilder::try_new(ds, key_columns)
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .when_matched(when_matched)
            .when_not_matched(when_not_matched)
            .try_build()
            .map_err(|e| OmniError::Lance(e.to_string()))?;

        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        let (new_ds, _stats) = job
            .execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        self.table_state(dataset_uri, &new_ds).await
    }

    pub async fn merge_insert_batches(
        &self,
        dataset_uri: &str,
        ds: Dataset,
        batches: Vec<RecordBatch>,
        key_columns: Vec<String>,
        when_matched: WhenMatched,
        when_not_matched: WhenNotMatched,
    ) -> Result<TableState> {
        if batches.is_empty() {
            return self.table_state(dataset_uri, &ds).await;
        }
        let batch = if batches.len() == 1 {
            batches.into_iter().next().unwrap()
        } else {
            let schema = batches[0].schema();
            concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
        };
        self.merge_insert_batch(
            dataset_uri,
            ds,
            batch,
            key_columns,
            when_matched,
            when_not_matched,
        )
        .await
    }

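    /// Deletes all rows matching the SQL `filter` and reports the new version,
    /// remaining row count, and number of deleted rows.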
    pub async fn delete_where(
        &self,
        dataset_uri: &str,
        ds: &mut Dataset,
        filter: &str,
    ) -> Result<DeleteState> {
        let delete_result = ds
            .delete(filter)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(DeleteState {
            version: delete_result.new_dataset.version().version,
            row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
            deleted_rows: delete_result.num_deleted_rows as usize,
            version_metadata: self
                .dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
        })
    }

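    /// Lists the non-system indices that cover exactly the given column; used
    /// by the `has_*_index` helpers below.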
    async fn user_indices_for_column(
        &self,
        ds: &Dataset,
        column: &str,
    ) -> Result<Vec<IndexMetadata>> {
        let field_id = ds
            .schema()
            .field(column)
            .map(|field| field.id)
            .ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "dataset is missing expected index column '{}'",
                    column
                ))
            })?;
        let indices = ds
            .load_indices()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(indices
            .iter()
            .filter(|index| !is_system_index(index))
            .filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
            .cloned()
            .collect())
    }

    pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| details.type_url.ends_with("BTreeIndexDetails"))
                .unwrap_or(false)
        }))
    }

    pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| IndexDetails(details.clone()).supports_fts())
                .unwrap_or(false)
        }))
    }

    pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
        let indices = self.user_indices_for_column(ds, column).await?;
        Ok(indices.iter().any(|index| {
            index
                .index_details
                .as_ref()
                .map(|details| IndexDetails(details.clone()).is_vector())
                .unwrap_or(false)
        }))
    }

    pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
        let params = ScalarIndexParams::default();
        ds.create_index_builder(columns, IndexType::BTree, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
        let params = InvertedIndexParams::default();
        ds.create_index_builder(&[column], IndexType::Inverted, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
        let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
        ds.create_index_builder(&[column], IndexType::Vector, &params)
            .replace(true)
            .await
            .map(|_| ())
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
        let batch = RecordBatch::new_empty(schema.clone());
        Self::write_dataset(dataset_uri, batch).await
    }

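    /// Returns the `_rowid` of the first row matching `filter`, or `None` when
    /// nothing matches. Only the `id` column is projected, with row ids
    /// requested on the scan.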
    pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
        let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
            .await?
            .try_collect::<Vec<RecordBatch>>()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(batches.iter().find_map(|batch| {
            batch
                .column_by_name("_rowid")
                .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
                .and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
        }))
    }

    pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
        let params = WriteParams {
            mode: WriteMode::Create,
            enable_stable_row_ids: true,
            data_storage_version: Some(LanceFileVersion::V2_2),
            allow_external_blob_outside_bases: true,
            ..Default::default()
        };
        Dataset::write(reader, dataset_uri, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))
    }
}