use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use arrow_select::concat::concat_batches;
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::transaction::{Operation, Transaction, TransactionBuilder};
use lance::dataset::{
CommitBuilder, InsertBuilder, MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode,
WriteParams,
};
use lance::datatypes::BlobHandling;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
use lance_index::{DatasetIndexExt, IndexType, is_system_index};
use lance_linalg::distance::MetricType;
use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
use lance_table::rowids::{RowIdSequence, write_row_ids};
use std::sync::Arc;
use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
use crate::db::{Snapshot, SubTableEntry};
use crate::error::{OmniError, Result};
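/// Snapshot of a table after an operation: the Lance version it produced, the
/// row count at that version, and the manifest metadata recorded for it.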
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableState {
pub version: u64,
pub row_count: u64,
pub(crate) version_metadata: TableVersionMetadata,
}
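/// Outcome of `delete_where`: the new Lance version, the rows that remain,
/// and the number of rows the delete removed.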
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeleteState {
pub version: u64,
pub row_count: u64,
pub deleted_rows: usize,
pub(crate) version_metadata: TableVersionMetadata,
}
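/// An uncommitted Lance transaction plus the fragments it would add and the
/// committed fragment ids it would retire. Staged writes can be scanned
/// speculatively (`scan_with_staged`) and later applied via `commit_staged`.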
#[derive(Debug, Clone)]
pub struct StagedWrite {
pub transaction: Transaction,
pub new_fragments: Vec<Fragment>,
pub removed_fragment_ids: Vec<u64>,
}
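/// Thin handle over a root URI that opens, scans, mutates, and stages writes
/// against the Lance datasets stored beneath that root.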
#[derive(Debug, Clone)]
pub struct TableStore {
root_uri: String,
}
impl TableStore {
pub fn new(root_uri: &str) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
}
}
pub fn root_uri(&self) -> &str {
&self.root_uri
}
pub fn dataset_uri(&self, table_path: &str) -> String {
format!("{}/{}", self.root_uri, table_path)
}
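    /// Recovers the bare table path from a dataset URI by stripping the store
    /// root prefix and dropping any `/tree/...` suffix.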
fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
let table_path = dataset_uri
.strip_prefix(&prefix)
.map(|path| path.to_string())
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"dataset uri '{}' is not under root '{}'",
dataset_uri, self.root_uri
))
})?;
Ok(table_path
.split_once("/tree/")
.map(|(path, _)| path.to_string())
.unwrap_or(table_path))
}
fn dataset_version_metadata(
&self,
dataset_uri: &str,
ds: &Dataset,
) -> Result<TableVersionMetadata> {
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
}
pub async fn open_snapshot_table(
&self,
snapshot: &Snapshot,
table_key: &str,
) -> Result<Dataset> {
snapshot.open(table_key).await
}
pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
entry.open(&self.root_uri).await
}
pub async fn open_dataset_head(
&self,
dataset_uri: &str,
branch: Option<&str>,
) -> Result<Dataset> {
let ds = Dataset::open(dataset_uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
match branch {
Some(branch) if branch != "main" => ds
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string())),
_ => Ok(ds),
}
}
pub async fn open_dataset_head_for_write(
&self,
table_key: &str,
dataset_uri: &str,
branch: Option<&str>,
) -> Result<Dataset> {
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
}
pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
let mut ds = Dataset::open(dataset_uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
ds.delete_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn open_dataset_at_state(
&self,
table_path: &str,
branch: Option<&str>,
version: u64,
) -> Result<Dataset> {
let ds = self
.open_dataset_head(&self.dataset_uri(table_path), branch)
.await?;
ds.checkout_version(version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub fn ensure_expected_version(
&self,
ds: &Dataset,
table_key: &str,
expected_version: u64,
) -> Result<()> {
let actual = ds.version().version;
if actual != expected_version {
return Err(OmniError::manifest_expected_version_mismatch(
table_key,
expected_version,
actual,
));
}
Ok(())
}
pub async fn reopen_for_mutation(
&self,
dataset_uri: &str,
branch: Option<&str>,
table_key: &str,
expected_version: u64,
) -> Result<Dataset> {
let ds = self
.open_dataset_head_for_write(table_key, dataset_uri, branch)
.await?;
self.ensure_expected_version(&ds, table_key, expected_version)?;
Ok(ds)
}
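    /// Creates `target_branch` at `source_version`, tolerating a concurrent
    /// creator: if `create_branch` fails but the branch already exists at the
    /// expected version, the existing branch is returned instead of the error.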
pub async fn fork_branch_from_state(
&self,
dataset_uri: &str,
source_branch: Option<&str>,
table_key: &str,
source_version: u64,
target_branch: &str,
) -> Result<Dataset> {
let mut source_ds = self
.open_dataset_head(dataset_uri, source_branch)
.await?
.checkout_version(source_version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.ensure_expected_version(&source_ds, table_key, source_version)?;
match source_ds
.create_branch(target_branch, source_version, None)
.await
{
Ok(_) => {}
Err(create_err) => match self
.open_dataset_head(dataset_uri, Some(target_branch))
.await
{
Ok(ds) => {
self.ensure_expected_version(&ds, table_key, source_version)?;
return Ok(ds);
}
Err(_) => return Err(OmniError::Lance(create_err.to_string())),
},
}
let ds = self
.open_dataset_head(dataset_uri, Some(target_branch))
.await?;
self.ensure_expected_version(&ds, table_key, source_version)?;
Ok(ds)
}
pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
self.scan(ds, None, None, None).await
}
pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
if !has_blob_columns {
return self.scan_batches(ds).await;
}
let mut scanner = ds.scan();
scanner.blob_handling(BlobHandling::AllBinary);
scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan_stream(
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
) -> Result<DatasetRecordBatchStream> {
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
}
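    /// Builds a scanner with the common knobs (projection, filter, ordering,
    /// row ids), then lets `configure` apply caller-specific settings before
    /// the scanner is turned into a stream.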
pub async fn scan_stream_with<F>(
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
configure: F,
) -> Result<DatasetRecordBatchStream>
where
F: FnOnce(&mut Scanner) -> Result<()>,
{
let mut scanner = ds.scan();
if with_row_id {
scanner.with_row_id();
}
if let Some(columns) = projection {
scanner
.project(columns)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if let Some(filter_sql) = filter {
scanner
.filter(filter_sql)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if let Some(ordering) = order_by {
scanner
.order_by(Some(ordering))
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
configure(&mut scanner)?;
scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan(
&self,
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
) -> Result<Vec<RecordBatch>> {
Self::scan_stream(ds, projection, filter, order_by, false)
.await?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan_with<F>(
&self,
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
configure: F,
) -> Result<Vec<RecordBatch>>
where
F: FnOnce(&mut Scanner) -> Result<()>,
{
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
.await?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
ds.count_rows(filter)
.await
.map(|count| count as usize)
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub fn dataset_version(&self, ds: &Dataset) -> u64 {
ds.version().version
}
pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
Ok(TableState {
version: self.dataset_version(ds),
row_count: self.count_rows(ds, None).await? as u64,
version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
})
}
pub async fn append_batch(
&self,
dataset_uri: &str,
ds: &mut Dataset,
batch: RecordBatch,
) -> Result<TableState> {
if batch.num_rows() == 0 {
return self.table_state(dataset_uri, ds).await;
}
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
..Default::default()
};
ds.append(reader, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.table_state(dataset_uri, ds).await
}
pub async fn append_or_create_batch(
dataset_uri: &str,
dataset: Option<Dataset>,
batch: RecordBatch,
) -> Result<Dataset> {
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
match dataset {
Some(mut ds) => {
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
..Default::default()
};
ds.append(reader, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(ds)
}
None => {
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}
}
pub async fn overwrite_batch(
&self,
dataset_uri: &str,
ds: &mut Dataset,
batch: RecordBatch,
) -> Result<TableState> {
ds.truncate_table()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.append_batch(dataset_uri, ds, batch).await
}
pub async fn overwrite_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let params = WriteParams {
mode: WriteMode::Overwrite,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
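    /// Runs a Lance merge-insert (upsert) of `batch` into `ds`, matching rows
    /// on `key_columns`. An empty batch is a no-op that just reports the
    /// current table state.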
pub async fn merge_insert_batch(
&self,
dataset_uri: &str,
ds: Dataset,
batch: RecordBatch,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<TableState> {
if batch.num_rows() == 0 {
return self.table_state(dataset_uri, &ds).await;
}
let ds = Arc::new(ds);
let job = MergeInsertBuilder::try_new(ds, key_columns)
.map_err(|e| OmniError::Lance(e.to_string()))?
.when_matched(when_matched)
.when_not_matched(when_not_matched)
.try_build()
.map_err(|e| OmniError::Lance(e.to_string()))?;
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let (new_ds, _stats) = job
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.table_state(dataset_uri, &new_ds).await
}
pub async fn merge_insert_batches(
&self,
dataset_uri: &str,
ds: Dataset,
batches: Vec<RecordBatch>,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<TableState> {
if batches.is_empty() {
return self.table_state(dataset_uri, &ds).await;
}
let batch = if batches.len() == 1 {
batches.into_iter().next().unwrap()
} else {
let schema = batches[0].schema();
concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
};
self.merge_insert_batch(
dataset_uri,
ds,
batch,
key_columns,
when_matched,
when_not_matched,
)
.await
}
pub async fn delete_where(
&self,
dataset_uri: &str,
ds: &mut Dataset,
filter: &str,
) -> Result<DeleteState> {
let delete_result = ds
.delete(filter)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(DeleteState {
version: delete_result.new_dataset.version().version,
row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
deleted_rows: delete_result.num_deleted_rows as usize,
version_metadata: self
.dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
})
}
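    /// Writes `batch` as new fragments without committing. Fragment ids and,
    /// for stable-row-id datasets, row id sequences are assigned eagerly so
    /// that several staged writes can coexist before a single commit;
    /// `prior_stages` offsets the ids already handed out in this session.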
pub async fn stage_append(
&self,
ds: &Dataset,
batch: RecordBatch,
prior_stages: &[StagedWrite],
) -> Result<StagedWrite> {
if batch.num_rows() == 0 {
return Err(OmniError::manifest_internal(
"stage_append called with empty batch".to_string(),
));
}
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
..Default::default()
};
let transaction = InsertBuilder::new(Arc::new(ds.clone()))
            .with_params(&params)
.execute_uncommitted(vec![batch])
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut new_fragments = match &transaction.operation {
Operation::Append { fragments } => fragments.clone(),
Operation::Overwrite { fragments, .. } => fragments.clone(),
other => {
return Err(OmniError::manifest_internal(format!(
"stage_append: unexpected Lance operation {:?}",
std::mem::discriminant(other)
)));
}
};
let next_id_base = ds.manifest.max_fragment_id.unwrap_or(0) as u64
+ 1
+ prior_stages_fragment_count(prior_stages);
assign_fragment_ids(&mut new_fragments, next_id_base);
if ds.manifest.uses_stable_row_ids() {
let prior_rows = prior_stages_row_count(prior_stages)?;
let start_row_id = ds.manifest.next_row_id + prior_rows;
assign_row_id_meta(&mut new_fragments, start_row_id)?;
}
Ok(StagedWrite {
transaction,
new_fragments,
removed_fragment_ids: Vec::new(),
})
}
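    /// Stages a merge-insert without committing it. The resulting
    /// `StagedWrite` carries both the updated/new fragments and the fragment
    /// ids the merge would retire, so speculative scans reflect the
    /// post-merge view.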
pub async fn stage_merge_insert(
&self,
ds: Dataset,
batch: RecordBatch,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<StagedWrite> {
if batch.num_rows() == 0 {
return Err(OmniError::manifest_internal(
"stage_merge_insert called with empty batch".to_string(),
));
}
let ds = Arc::new(ds);
let job = MergeInsertBuilder::try_new(ds, key_columns)
.map_err(|e| OmniError::Lance(e.to_string()))?
.when_matched(when_matched)
.when_not_matched(when_not_matched)
.try_build()
.map_err(|e| OmniError::Lance(e.to_string()))?;
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let stream = lance_datafusion::utils::reader_to_stream(Box::new(reader));
let uncommitted = job
.execute_uncommitted(stream)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let (new_fragments, removed_fragment_ids) = match &uncommitted.transaction.operation {
Operation::Update {
new_fragments,
updated_fragments,
removed_fragment_ids,
..
} => {
let mut all = updated_fragments.clone();
all.extend(new_fragments.iter().cloned());
(all, removed_fragment_ids.clone())
}
Operation::Append { fragments } => (fragments.clone(), Vec::new()),
other => {
return Err(OmniError::manifest_internal(format!(
"stage_merge_insert: unexpected Lance operation {:?}",
std::mem::discriminant(other)
)));
}
};
Ok(StagedWrite {
transaction: uncommitted.transaction,
new_fragments,
removed_fragment_ids,
})
}
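    /// Commits a previously staged transaction against `ds`.
    ///
    /// A minimal staging round-trip, as an illustrative sketch (error
    /// handling elided):
    ///
    /// ```ignore
    /// let staged = store.stage_append(&ds, batch, &[]).await?;
    /// let ds = store.commit_staged(Arc::new(ds), staged.transaction).await?;
    /// ```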
pub async fn commit_staged(
&self,
ds: Arc<Dataset>,
transaction: Transaction,
) -> Result<Dataset> {
CommitBuilder::new(ds)
.execute(transaction)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
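    /// Stages a full-table overwrite without committing: the staged fragments
    /// replace every committed fragment, so `removed_fragment_ids` lists the
    /// entire current fragment set.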
pub async fn stage_overwrite(
&self,
ds: &Dataset,
batch: RecordBatch,
) -> Result<StagedWrite> {
if batch.num_rows() == 0 {
return Err(OmniError::manifest_internal(
"stage_overwrite called with empty batch".to_string(),
));
}
let params = WriteParams {
mode: WriteMode::Overwrite,
allow_external_blob_outside_bases: true,
..Default::default()
};
let transaction = InsertBuilder::new(Arc::new(ds.clone()))
            .with_params(&params)
.execute_uncommitted(vec![batch])
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut new_fragments = match &transaction.operation {
Operation::Overwrite { fragments, .. } => fragments.clone(),
other => {
return Err(OmniError::manifest_internal(format!(
"stage_overwrite: unexpected Lance operation {:?}",
std::mem::discriminant(other)
)));
}
};
assign_fragment_ids(&mut new_fragments, 1);
if ds.manifest.uses_stable_row_ids() {
assign_row_id_meta(&mut new_fragments, 0)?;
}
let removed_fragment_ids: Vec<u64> =
ds.manifest.fragments.iter().map(|f| f.id).collect();
Ok(StagedWrite {
transaction,
new_fragments,
removed_fragment_ids,
})
}
pub async fn stage_create_btree_index(
&self,
ds: &Dataset,
columns: &[&str],
) -> Result<StagedWrite> {
let params = ScalarIndexParams::default();
let mut ds_clone = ds.clone();
let new_idx = ds_clone
            .create_index_builder(columns, IndexType::BTree, &params)
.replace(true)
.execute_uncommitted()
.await
.map_err(|e| {
OmniError::Lance(format!("stage_create_btree_index: {}", e))
})?;
let removed_indices: Vec<IndexMetadata> = ds
.load_indices()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.iter()
.filter(|idx| idx.name == new_idx.name)
.cloned()
.collect();
let transaction = TransactionBuilder::new(
new_idx.dataset_version,
Operation::CreateIndex {
new_indices: vec![new_idx],
removed_indices,
},
)
.build();
Ok(StagedWrite {
transaction,
new_fragments: Vec::new(),
removed_fragment_ids: Vec::new(),
})
}
pub async fn stage_create_inverted_index(
&self,
ds: &Dataset,
column: &str,
) -> Result<StagedWrite> {
let params = InvertedIndexParams::default();
let mut ds_clone = ds.clone();
let new_idx = ds_clone
            .create_index_builder(&[column], IndexType::Inverted, &params)
.replace(true)
.execute_uncommitted()
.await
.map_err(|e| {
OmniError::Lance(format!("stage_create_inverted_index: {}", e))
})?;
let removed_indices: Vec<IndexMetadata> = ds
.load_indices()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.iter()
.filter(|idx| idx.name == new_idx.name)
.cloned()
.collect();
let transaction = TransactionBuilder::new(
new_idx.dataset_version,
Operation::CreateIndex {
new_indices: vec![new_idx],
removed_indices,
},
)
.build();
Ok(StagedWrite {
transaction,
new_fragments: Vec::new(),
removed_fragment_ids: Vec::new(),
})
}
pub async fn scan_with_staged(
&self,
ds: &Dataset,
staged: &[StagedWrite],
projection: Option<&[&str]>,
filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
if staged.is_empty() {
return self.scan(ds, projection, filter, None).await;
}
let mut scanner = ds.scan();
if let Some(cols) = projection {
let owned: Vec<String> = cols.iter().map(|s| s.to_string()).collect();
scanner
.project(&owned)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if let Some(f) = filter {
scanner
.filter(f)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
scanner.with_fragments(combine_committed_with_staged(ds, staged));
let stream = scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
stream
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
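    /// Scans the committed dataset and appends `pending_batches`, which are
    /// filtered and projected through an in-memory DataFusion table. When
    /// `key_column` is given, committed rows whose key also appears in a
    /// pending batch are shadowed by the pending version; that key column
    /// must therefore be part of the projection.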
pub async fn scan_with_pending(
&self,
committed_ds: &Dataset,
pending_batches: &[RecordBatch],
pending_schema: Option<SchemaRef>,
projection: Option<&[&str]>,
filter: Option<&str>,
key_column: Option<&str>,
) -> Result<Vec<RecordBatch>> {
if let (Some(key_col), Some(cols)) = (key_column, projection) {
if !cols.iter().any(|c| *c == key_col) {
return Err(OmniError::Lance(format!(
"scan_with_pending: key_column '{}' must appear in projection \
when merge-shadow semantics are requested (got projection = {:?})",
key_col, cols
)));
}
}
let committed = self.scan(committed_ds, projection, filter, None).await?;
if pending_batches.is_empty() {
return Ok(committed);
}
let committed = match key_column {
Some(key_col) => {
let pending_keys = collect_string_column_values(pending_batches, key_col)?;
if pending_keys.is_empty() {
committed
} else {
filter_out_rows_where_string_in(committed, key_col, &pending_keys)?
}
}
None => committed,
};
let pending = scan_pending_batches(
pending_batches,
pending_schema,
projection,
filter,
)
.await?;
let mut out = committed;
out.extend(pending);
Ok(out)
}
pub async fn count_rows_with_staged(
&self,
ds: &Dataset,
staged: &[StagedWrite],
filter: Option<String>,
) -> Result<usize> {
if staged.is_empty() {
return self.count_rows(ds, filter).await;
}
let mut scanner = ds.scan();
if let Some(f) = filter {
scanner
.filter(&f)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
scanner.with_fragments(combine_committed_with_staged(ds, staged));
let count = scanner
.count_rows()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(count as usize)
}
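    /// Returns the non-system indices that cover exactly the given column.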
async fn user_indices_for_column(
&self,
ds: &Dataset,
column: &str,
) -> Result<Vec<IndexMetadata>> {
let field_id = ds
.schema()
.field(column)
.map(|field| field.id)
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"dataset is missing expected index column '{}'",
column
))
})?;
let indices = ds
.load_indices()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(indices
.iter()
.filter(|index| !is_system_index(index))
.filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
.cloned()
.collect())
}
pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| details.type_url.ends_with("BTreeIndexDetails"))
.unwrap_or(false)
}))
}
pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| IndexDetails(details.clone()).supports_fts())
.unwrap_or(false)
}))
}
pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| IndexDetails(details.clone()).is_vector())
.unwrap_or(false)
}))
}
pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
let params = ScalarIndexParams::default();
        ds.create_index_builder(columns, IndexType::BTree, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
let params = InvertedIndexParams::default();
        ds.create_index_builder(&[column], IndexType::Inverted, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
        ds.create_index_builder(&[column], IndexType::Vector, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
let batch = RecordBatch::new_empty(schema.clone());
Self::write_dataset(dataset_uri, batch).await
}
pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
.await?
.try_collect::<Vec<RecordBatch>>()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(batches.iter().find_map(|batch| {
batch
.column_by_name("_rowid")
.and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
.and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
}))
}
pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}
fn prior_stages_fragment_count(prior_stages: &[StagedWrite]) -> u64 {
prior_stages
.iter()
.map(|s| s.new_fragments.len() as u64)
.sum()
}
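/// Assigns ids to freshly staged fragments. Uncommitted fragments come back
/// with id 0, which doubles as the "unassigned" sentinel here; fragments that
/// already carry a non-zero id are left alone. Callers start numbering at
/// `max_fragment_id + 1` (or at 1 for overwrites) so staged ids collide with
/// neither committed ids nor the sentinel.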
fn assign_fragment_ids(fragments: &mut [Fragment], start_id: u64) {
for (i, fragment) in fragments.iter_mut().enumerate() {
if fragment.id == 0 {
fragment.id = start_id + i as u64;
}
}
}
fn prior_stages_row_count(prior_stages: &[StagedWrite]) -> Result<u64> {
let mut total: u64 = 0;
for stage in prior_stages {
for fragment in &stage.new_fragments {
let physical_rows = fragment.physical_rows.ok_or_else(|| {
OmniError::manifest_internal(
"prior_stages_row_count: fragment is missing physical_rows".to_string(),
)
})? as u64;
total += physical_rows;
}
}
Ok(total)
}
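/// For stable-row-id datasets, gives each staged fragment a contiguous inline
/// row id sequence starting at `start_row_id`, mirroring what Lance would
/// assign at commit time. Fragments that already carry row id metadata are
/// skipped.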
fn assign_row_id_meta(fragments: &mut [Fragment], start_row_id: u64) -> Result<()> {
let mut next_row_id = start_row_id;
for fragment in fragments {
if fragment.row_id_meta.is_some() {
continue;
}
let physical_rows = fragment.physical_rows.ok_or_else(|| {
OmniError::manifest_internal(
"stage_append: fragment is missing physical_rows".to_string(),
)
})? as u64;
let row_ids = next_row_id..(next_row_id + physical_rows);
let sequence = RowIdSequence::from(row_ids);
let serialized = write_row_ids(&sequence);
fragment.row_id_meta = Some(RowIdMeta::Inline(serialized));
next_row_id += physical_rows;
}
Ok(())
}
fn collect_string_column_values(
batches: &[RecordBatch],
column: &str,
) -> Result<std::collections::HashSet<String>> {
use arrow_array::{Array, StringArray};
let mut out = std::collections::HashSet::new();
for batch in batches {
let Some(col) = batch.column_by_name(column) else {
return Err(OmniError::Lance(format!(
"scan_with_pending: pending batch missing key column '{}'",
column
)));
};
let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
OmniError::Lance(format!(
"scan_with_pending: key column '{}' is not Utf8",
column
))
})?;
for i in 0..arr.len() {
if arr.is_valid(i) {
out.insert(arr.value(i).to_string());
}
}
}
Ok(out)
}
fn filter_out_rows_where_string_in(
batches: Vec<RecordBatch>,
column: &str,
excluded: &std::collections::HashSet<String>,
) -> Result<Vec<RecordBatch>> {
use arrow_array::{Array, BooleanArray, StringArray};
let mut out = Vec::with_capacity(batches.len());
for batch in batches {
if batch.num_rows() == 0 {
out.push(batch);
continue;
}
let col = batch.column_by_name(column).ok_or_else(|| {
OmniError::manifest_internal(format!(
"scan_with_pending: committed batch missing key column '{}' \
(the up-front projection check should have rejected this)",
column
))
})?;
let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
OmniError::Lance(format!(
"scan_with_pending: committed column '{}' is not Utf8",
column
))
})?;
let mask: BooleanArray = (0..arr.len())
.map(|i| {
if arr.is_valid(i) {
Some(!excluded.contains(arr.value(i)))
} else {
Some(true)
}
})
.collect();
let filtered = arrow_select::filter::filter_record_batch(&batch, &mask)
.map_err(|e| OmniError::Lance(e.to_string()))?;
out.push(filtered);
}
Ok(out)
}
async fn scan_pending_batches(
pending_batches: &[RecordBatch],
pending_schema: Option<SchemaRef>,
projection: Option<&[&str]>,
filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
let schema = pending_schema.unwrap_or_else(|| pending_batches[0].schema());
let ctx = datafusion::execution::context::SessionContext::new();
let mem = datafusion::datasource::MemTable::try_new(
schema,
vec![pending_batches.to_vec()],
)
.map_err(|e| OmniError::Lance(e.to_string()))?;
ctx.register_table("pending", Arc::new(mem))
.map_err(|e| OmniError::Lance(e.to_string()))?;
let proj = projection
.map(|cols| {
cols.iter()
.map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
.collect::<Vec<_>>()
.join(", ")
})
.unwrap_or_else(|| "*".to_string());
let where_clause = filter
.map(|f| format!("WHERE {f}"))
.unwrap_or_default();
let sql = format!("SELECT {proj} FROM pending {where_clause}");
let df = ctx
.sql(&sql)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
df.collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
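/// Builds the fragment list a speculative scan should see: the committed
/// fragments minus those removed by staged writes, followed by every staged
/// fragment.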
fn combine_committed_with_staged(ds: &Dataset, staged: &[StagedWrite]) -> Vec<Fragment> {
let removed: std::collections::HashSet<u64> = staged
.iter()
.flat_map(|w| w.removed_fragment_ids.iter().copied())
.collect();
let mut combined: Vec<Fragment> = ds
.manifest
.fragments
.iter()
.filter(|f| !removed.contains(&f.id))
.cloned()
.collect();
for write in staged {
combined.extend(write.new_fragments.iter().cloned());
}
combined
}