use arrow_array::{
Array, ArrayRef, RecordBatch, StringArray, StructArray, UInt8Array, UInt32Array, UInt64Array,
};
use arrow_schema::SchemaRef;
use futures::TryStreamExt;
use lance::Dataset;
use lance::blob::BlobArrayBuilder;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::transaction::{Operation, Transaction, TransactionBuilder};
use lance::dataset::write::merge_insert::SourceDedupeBehavior;
use lance::dataset::{
CommitBuilder, InsertBuilder, MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode,
WriteParams,
};
use lance::datatypes::{BlobKind, Schema as LanceSchema};
use lance::index::DatasetIndexExt;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
use lance_index::{IndexType, is_system_index};
use lance_linalg::distance::MetricType;
use lance_table::format::{Fragment, IndexMetadata, RowIdMeta};
use lance_table::rowids::{RowIdSequence, write_row_ids};
use std::sync::Arc;
use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
use crate::db::{Snapshot, SubTableEntry};
use crate::error::{OmniError, Result};
use crate::storage_layer::ForkOutcome;
/// Version, row count, and version metadata describing one state of a table dataset.
///
/// Produced by `TableStore::table_state` from an open dataset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableState {
/// Dataset version number, read from `Dataset::version()`.
pub version: u64,
/// Number of rows counted in the dataset (no filter applied).
pub row_count: u64,
/// Crate-internal metadata derived from the dataset and its table path.
pub(crate) version_metadata: TableVersionMetadata,
}
/// Result of a filtered delete, describing the dataset that the delete produced.
///
/// Produced by `TableStore::delete_where`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeleteState {
/// Version of the dataset returned by the delete.
pub version: u64,
/// Rows remaining in that dataset.
pub row_count: u64,
/// Number of rows the delete reported as removed.
pub deleted_rows: usize,
/// Crate-internal metadata derived from the post-delete dataset.
pub(crate) version_metadata: TableVersionMetadata,
}
/// Outcome of checking whether a key column is fully served by a BTREE index.
///
/// Returned by `TableStore::key_column_index_coverage`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexCoverage {
/// None of the degradation checks fired.
Indexed,
/// The column cannot rely on the index; `reason` is a human-readable explanation.
Degraded { reason: String },
}
/// A write that has been executed but not committed, plus the fragment changes it implies.
///
/// Produced by the `TableStore::stage_*` methods; the transaction can later be handed to
/// `TableStore::commit_staged`.
#[derive(Debug, Clone)]
pub struct StagedWrite {
/// The uncommitted Lance transaction.
pub transaction: Transaction,
/// Fragments the write adds (empty for index-only stages).
pub new_fragments: Vec<Fragment>,
/// Ids of committed fragments that the write replaces or removes.
pub removed_fragment_ids: Vec<u64>,
}
/// Helper for opening, scanning, and mutating Lance datasets that live under one root URI.
#[derive(Debug, Clone)]
pub struct TableStore {
// Root URI with trailing '/' characters removed (see `TableStore::new`).
root_uri: String,
}
impl TableStore {
/// Creates a store rooted at `root_uri`, dropping any trailing `/` characters.
pub fn new(root_uri: &str) -> Self {
    let normalized = root_uri.trim_end_matches('/');
    Self {
        root_uri: normalized.to_owned(),
    }
}
/// Returns the root URI held by this store.
pub fn root_uri(&self) -> &str {
    self.root_uri.as_str()
}
/// Builds a dataset URI by joining `table_path` onto the root URI with a single `/`.
pub fn dataset_uri(&self, table_path: &str) -> String {
    let mut uri = String::with_capacity(self.root_uri.len() + 1 + table_path.len());
    uri.push_str(&self.root_uri);
    uri.push('/');
    uri.push_str(table_path);
    uri
}
/// Recovers the table path from a dataset URI located under this store's root.
///
/// Everything from the first `/tree/` onward is discarded, so only the part before it
/// is returned.
///
/// # Errors
/// Returns a manifest-internal error when `dataset_uri` does not start with `<root>/`.
fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
    let root_with_slash = format!("{}/", self.root_uri.trim_end_matches('/'));
    let Some(relative) = dataset_uri.strip_prefix(&root_with_slash) else {
        return Err(OmniError::manifest_internal(format!(
            "dataset uri '{}' is not under root '{}'",
            dataset_uri, self.root_uri
        )));
    };
    let table_path = match relative.split_once("/tree/") {
        Some((before_tree, _)) => before_tree,
        None => relative,
    };
    Ok(table_path.to_string())
}
/// Derives the `TableVersionMetadata` for `ds`, resolving its table path from `dataset_uri`.
///
/// # Errors
/// Fails when `dataset_uri` is not under the store root or when metadata extraction fails.
fn dataset_version_metadata(
    &self,
    dataset_uri: &str,
    ds: &Dataset,
) -> Result<TableVersionMetadata> {
    self.table_path_from_dataset_uri(dataset_uri)
        .and_then(|table_path| {
            TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
        })
}
/// Opens the dataset registered under `table_key` in `snapshot`.
pub async fn open_snapshot_table(
    &self,
    snapshot: &Snapshot,
    table_key: &str,
) -> Result<Dataset> {
    let dataset = snapshot.open(table_key).await?;
    Ok(dataset)
}
/// Opens the dataset described by a sub-table `entry`, relative to this store's root URI.
pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
    let dataset = entry.open(&self.root_uri).await?;
    Ok(dataset)
}
/// Opens the dataset at `dataset_uri` and, for a named branch other than `"main"`,
/// checks that branch out.
///
/// `None` and `Some("main")` both return the dataset exactly as opened.
///
/// # Errors
/// Lance failures are reported as `OmniError::Lance` carrying the error text.
pub async fn open_dataset_head(
    &self,
    dataset_uri: &str,
    branch: Option<&str>,
) -> Result<Dataset> {
    let head = Dataset::open(dataset_uri)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let Some(name) = branch.filter(|name| *name != "main") else {
        return Ok(head);
    };
    head.checkout_branch(name)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Opens the head of a table for writing via `open_table_head_for_write`.
///
/// The table path is recovered from `dataset_uri`, which must be under the store root.
pub async fn open_dataset_head_for_write(
    &self,
    table_key: &str,
    dataset_uri: &str,
    branch: Option<&str>,
) -> Result<Dataset> {
    let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
    let opened =
        open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await?;
    Ok(opened)
}
/// Opens the dataset at `dataset_uri` and deletes `branch` from it.
///
/// # Errors
/// Any Lance failure (opening or deleting) is reported as `OmniError::Lance`.
pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
    let mut ds = match Dataset::open(dataset_uri).await {
        Ok(ds) => ds,
        Err(e) => return Err(OmniError::Lance(e.to_string())),
    };
    match ds.delete_branch(branch).await {
        Ok(()) => Ok(()),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Returns the names of all branches known to the dataset at `dataset_uri`.
///
/// The names are the keys of the map Lance returns, in that map's iteration order.
pub async fn list_branches(&self, dataset_uri: &str) -> Result<Vec<String>> {
    let ds = Dataset::open(dataset_uri)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    match ds.list_branches().await {
        Ok(branches) => Ok(branches.into_keys().collect()),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
pub async fn force_delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
let mut ds = Dataset::open(dataset_uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
match ds.force_delete_branch(branch).await {
Ok(()) => Ok(()),
Err(lance::Error::RefNotFound { .. }) | Err(lance::Error::NotFound { .. }) => Ok(()),
Err(e) => Err(OmniError::Lance(e.to_string())),
}
}
/// Opens the table at `table_path` on `branch` and checks out the given `version`.
pub async fn open_dataset_at_state(
    &self,
    table_path: &str,
    branch: Option<&str>,
    version: u64,
) -> Result<Dataset> {
    let uri = self.dataset_uri(table_path);
    let head = self.open_dataset_head(&uri, branch).await?;
    match head.checkout_version(version).await {
        Ok(pinned) => Ok(pinned),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Verifies that `ds` is at `expected_version`.
///
/// # Errors
/// Returns a manifest expected-version-mismatch error naming `table_key`, the expected
/// version, and the actual version when they differ.
pub fn ensure_expected_version(
    &self,
    ds: &Dataset,
    table_key: &str,
    expected_version: u64,
) -> Result<()> {
    let actual = ds.version().version;
    if actual == expected_version {
        Ok(())
    } else {
        Err(OmniError::manifest_expected_version_mismatch(
            table_key,
            expected_version,
            actual,
        ))
    }
}
/// Opens the table head for writing and confirms it is still at `expected_version`.
///
/// # Errors
/// Fails when the open fails or when the head's version differs from `expected_version`.
pub async fn reopen_for_mutation(
    &self,
    dataset_uri: &str,
    branch: Option<&str>,
    table_key: &str,
    expected_version: u64,
) -> Result<Dataset> {
    let opened = self
        .open_dataset_head_for_write(table_key, dataset_uri, branch)
        .await?;
    match self.ensure_expected_version(&opened, table_key, expected_version) {
        Ok(()) => Ok(opened),
        Err(mismatch) => Err(mismatch),
    }
}
/// Creates `target_branch` from `source_version` of `source_branch` and opens it.
///
/// Returns `ForkOutcome::Created` with the dataset opened on the new branch, or
/// `ForkOutcome::RefAlreadyExists` when branch creation fails and the branch is already
/// listed on the dataset.
///
/// # Errors
/// Fails when the source cannot be opened or checked out, when a version check does not
/// match `source_version`, or when branch creation fails and the branch is not listed
/// afterwards (the original creation error is returned).
pub async fn fork_branch_from_state(
&self,
dataset_uri: &str,
source_branch: Option<&str>,
table_key: &str,
source_version: u64,
target_branch: &str,
) -> Result<ForkOutcome<Dataset>> {
// Pin the source to exactly the version the fork should start from.
let mut source_ds = self
.open_dataset_head(dataset_uri, source_branch)
.await?
.checkout_version(source_version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.ensure_expected_version(&source_ds, table_key, source_version)?;
if let Err(create_err) = source_ds
.create_branch(target_branch, source_version, None)
.await
{
// Creation failed: distinguish "branch already there" from a real failure.
// A failure to list branches is treated the same as "branch not present".
let ref_exists = source_ds
.list_branches()
.await
.map(|b| b.contains_key(target_branch))
.unwrap_or(false);
if ref_exists {
return Ok(ForkOutcome::RefAlreadyExists);
}
return Err(OmniError::Lance(create_err.to_string()));
}
// Re-open on the new branch and confirm it sits at the forked version.
let ds = self
.open_dataset_head(dataset_uri, Some(target_branch))
.await?;
self.ensure_expected_version(&ds, table_key, source_version)?;
Ok(ForkOutcome::Created(ds))
}
/// Scans every row and column of `ds` with no filter or ordering.
pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
    let (projection, filter, order_by) = (None, None, None);
    self.scan(ds, projection, filter, order_by).await
}
/// Scans all of `ds` in a form suitable for writing back out.
///
/// Datasets without blob fields are scanned normally. When any field is a blob, the scan
/// includes row ids and each batch is passed through `materialize_blob_batch`.
pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
    let blob_present = ds.schema().fields_pre_order().any(|f| f.is_blob());
    if blob_present {
        let stream = Self::scan_stream(ds, None, None, None, true).await?;
        let raw: Vec<RecordBatch> = stream
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let mut rewritten = Vec::with_capacity(raw.len());
        for batch in raw {
            let materialized = Self::materialize_blob_batch(ds, batch).await?;
            rewritten.push(materialized);
        }
        Ok(rewritten)
    } else {
        self.scan_batches(ds).await
    }
}
/// Replaces blob-description columns in `batch` with rebuilt blob columns.
///
/// Returns `batch` unchanged when the dataset schema has no blob fields. Otherwise the
/// output batch contains exactly the dataset schema's top-level fields, in schema order:
/// blob columns are rebuilt through `rebuild_blob_column`, all others are reused as-is.
///
/// # Errors
/// Fails when `batch` has no `UInt64` `_rowid` column, lacks a column for a schema field,
/// a blob column is not a `StructArray`, or the rebuilt columns do not form a valid batch.
pub(crate) async fn materialize_blob_batch(
ds: &Dataset,
batch: RecordBatch,
) -> Result<RecordBatch> {
let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
if !has_blob_columns {
return Ok(batch);
}
// Row ids are needed to fetch blob payloads for the rows in this batch.
let row_ids = batch
.column_by_name("_rowid")
.and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
.ok_or_else(|| {
OmniError::Lance("expected _rowid column when materializing blobs".to_string())
})?
.values()
.iter()
.copied()
.collect::<Vec<_>>();
let schema: SchemaRef = Arc::new(ds.schema().into());
let mut columns = Vec::with_capacity(schema.fields().len());
for field in schema.fields() {
// Convert back to a Lance field so `is_blob` can be asked of it.
let lance_field = lance::datatypes::Field::try_from(field.as_ref())
.map_err(|e| OmniError::Lance(e.to_string()))?;
let column = batch.column_by_name(field.name()).ok_or_else(|| {
OmniError::Lance(format!("batch missing column '{}'", field.name()))
})?;
if lance_field.is_blob() {
let descriptions =
column
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
OmniError::Lance(format!(
"expected blob descriptions for '{}'",
field.name()
))
})?;
columns.push(
Self::rebuild_blob_column(ds, field.name(), descriptions, &row_ids).await?,
);
} else {
columns.push(column.clone());
}
}
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
}
/// Rebuilds one blob column by fetching the payloads of its non-null rows.
///
/// `descriptions` and `row_ids` describe the same rows, position by position. The result
/// has one entry per row: a null where the description is null, otherwise the blob bytes.
///
/// # Errors
/// Fails when a description cannot be interpreted, when fetching or reading a blob fails,
/// or when the number of fetched blobs differs from the number of non-null rows.
async fn rebuild_blob_column(
ds: &Dataset,
column_name: &str,
descriptions: &StructArray,
row_ids: &[u64],
) -> Result<ArrayRef> {
let mut builder = BlobArrayBuilder::new(row_ids.len());
let mut non_null_row_ids = Vec::new();
let mut row_has_blob = Vec::with_capacity(row_ids.len());
// First pass: record which rows hold a blob and collect their row ids.
for row in 0..row_ids.len() {
let is_null = Self::blob_description_is_null(descriptions, row)?;
row_has_blob.push(!is_null);
if !is_null {
non_null_row_ids.push(row_ids[row]);
}
}
// Fetch all payload handles in a single call; skipped when there is nothing to fetch.
let blob_files = if non_null_row_ids.is_empty() {
Vec::new()
} else {
Arc::new(ds.clone())
.take_blobs(&non_null_row_ids, column_name)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
};
let mut files = blob_files.into_iter();
// Second pass: emit a null or the next fetched blob, in row order.
// NOTE(review): this assumes take_blobs returns blobs in the order of the requested
// row ids; the checks below only detect a count mismatch — confirm.
for has_blob in row_has_blob {
if !has_blob {
builder
.push_null()
.map_err(|e| OmniError::Lance(e.to_string()))?;
continue;
}
let blob = files.next().ok_or_else(|| {
OmniError::Lance(format!(
"blob rewrite for '{}' lost alignment with source rows",
column_name
))
})?;
builder
.push_bytes(
blob.read()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?,
)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
// Leftover blobs mean more payloads were returned than rows asked for.
if files.next().is_some() {
return Err(OmniError::Lance(format!(
"blob rewrite for '{}' produced extra source blobs",
column_name
)));
}
builder
.finish()
.map_err(|e| OmniError::Lance(e.to_string()))
}
/// Decides whether row `row` of a blob-description struct represents a null blob.
///
/// Rules, applied in order:
/// - a null struct slot is null;
/// - without a `kind` field, the row is null when `position` or `size` is null;
/// - a null `kind` value is null;
/// - any kind other than `BlobKind::Inline` is non-null;
/// - an inline row is null only when `position` and `size` are each null or zero and
///   `blob_uri` is missing, not a string column, null, or empty.
///
/// # Errors
/// Fails when `position` or `size` is missing or not `UInt64`, when `kind` is neither
/// `UInt8` nor `UInt32`, when a `UInt32` kind does not fit in a `u8`, or when the kind
/// value is not a recognized `BlobKind`.
fn blob_description_is_null(descriptions: &StructArray, row: usize) -> Result<bool> {
    if descriptions.is_null(row) {
        return Ok(true);
    }
    let position = descriptions
        .column_by_name("position")
        .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        .ok_or_else(|| {
            OmniError::Lance(format!(
                "unrecognized blob description schema {:?}: missing UInt64 position field",
                descriptions.fields()
            ))
        })?;
    let size = descriptions
        .column_by_name("size")
        .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        .ok_or_else(|| {
            OmniError::Lance(format!(
                "unrecognized blob description schema {:?}: missing UInt64 size field",
                descriptions.fields()
            ))
        })?;
    let Some(kind_column) = descriptions.column_by_name("kind") else {
        return Ok(position.is_null(row) || size.is_null(row));
    };
    let kind = if let Some(kind) = kind_column.as_any().downcast_ref::<UInt8Array>() {
        if kind.is_null(row) {
            return Ok(true);
        }
        kind.value(row)
    } else if let Some(kind) = kind_column.as_any().downcast_ref::<UInt32Array>() {
        if kind.is_null(row) {
            return Ok(true);
        }
        // A plain `as u8` would wrap out-of-range values (256 would read as 0) and
        // misclassify the row, so reject anything that does not fit.
        let wide = kind.value(row);
        u8::try_from(wide).map_err(|_| {
            OmniError::Lance(format!(
                "blob description kind value {} does not fit in a u8",
                wide
            ))
        })?
    } else {
        return Err(OmniError::Lance(format!(
            "unrecognized blob description schema {:?}: kind field must be UInt8 or UInt32",
            descriptions.fields()
        )));
    };
    let kind = BlobKind::try_from(kind).map_err(|e| OmniError::Lance(e.to_string()))?;
    if kind != BlobKind::Inline {
        return Ok(false);
    }
    let blob_uri = descriptions
        .column_by_name("blob_uri")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .and_then(|arr| (!arr.is_null(row)).then(|| arr.value(row)));
    Ok((position.is_null(row) || position.value(row) == 0)
        && (size.is_null(row) || size.value(row) == 0)
        && blob_uri.unwrap_or("").is_empty())
}
pub async fn scan_stream(
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
) -> Result<DatasetRecordBatchStream> {
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
}
/// Starts a streaming scan of `ds`, letting the caller adjust the scanner before it runs.
///
/// Settings are applied in this order: row ids, projection, filter, ordering, then
/// `configure`. Each optional setting is skipped when its argument is `None`/`false`.
///
/// # Errors
/// Lance failures are reported as `OmniError::Lance`; an error returned by `configure`
/// is propagated as-is.
pub async fn scan_stream_with<F>(
    ds: &Dataset,
    projection: Option<&[&str]>,
    filter: Option<&str>,
    order_by: Option<Vec<ColumnOrdering>>,
    with_row_id: bool,
    configure: F,
) -> Result<DatasetRecordBatchStream>
where
    F: FnOnce(&mut Scanner) -> Result<()>,
{
    let mut scanner = ds.scan();
    if with_row_id {
        scanner.with_row_id();
    }
    if let Some(cols) = projection {
        if let Err(e) = scanner.project(cols) {
            return Err(OmniError::Lance(e.to_string()));
        }
    }
    if let Some(predicate) = filter {
        if let Err(e) = scanner.filter(predicate) {
            return Err(OmniError::Lance(e.to_string()));
        }
    }
    if let Some(ordering) = order_by {
        if let Err(e) = scanner.order_by(Some(ordering)) {
            return Err(OmniError::Lance(e.to_string()));
        }
    }
    configure(&mut scanner)?;
    match scanner.try_into_stream().await {
        Ok(stream) => Ok(stream),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Scans `ds` and collects every resulting batch into memory (row ids not included).
pub async fn scan(
    &self,
    ds: &Dataset,
    projection: Option<&[&str]>,
    filter: Option<&str>,
    order_by: Option<Vec<ColumnOrdering>>,
) -> Result<Vec<RecordBatch>> {
    let stream = Self::scan_stream(ds, projection, filter, order_by, false).await?;
    let batches: Vec<RecordBatch> = stream
        .try_collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    Ok(batches)
}
/// Scans `ds` with a caller-supplied scanner tweak and collects every batch into memory.
pub async fn scan_with<F>(
    &self,
    ds: &Dataset,
    projection: Option<&[&str]>,
    filter: Option<&str>,
    order_by: Option<Vec<ColumnOrdering>>,
    with_row_id: bool,
    configure: F,
) -> Result<Vec<RecordBatch>>
where
    F: FnOnce(&mut Scanner) -> Result<()>,
{
    let stream =
        Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
            .await?;
    let batches: Vec<RecordBatch> = stream
        .try_collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    Ok(batches)
}
/// Returns the `key_col` and `opposite_col` values of rows whose `key_col` is one of `keys`.
///
/// An empty `keys` slice yields an empty result without touching the dataset. The
/// membership test is passed to the scanner as an expression rather than as SQL text.
pub async fn scan_edges_by_endpoint(
    ds: &Dataset,
    key_col: &str,
    opposite_col: &str,
    keys: &[String],
) -> Result<Vec<RecordBatch>> {
    use datafusion::prelude::{col, lit};
    if keys.is_empty() {
        return Ok(Vec::new());
    }
    let mut literals: Vec<datafusion::prelude::Expr> = Vec::with_capacity(keys.len());
    for key in keys {
        literals.push(lit(key.clone()));
    }
    let membership = col(key_col).in_list(literals, false);
    let projected = [key_col, opposite_col];
    let stream = Self::scan_stream_with(
        ds,
        Some(&projected[..]),
        None,
        None,
        false,
        |scanner| {
            scanner.filter_expr(membership);
            Ok(())
        },
    )
    .await?;
    stream
        .try_collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Reports whether `column` is fully covered by a single-column BTREE index.
///
/// Returns `IndexCoverage::Degraded` with a reason when:
/// - the column is not in the dataset schema;
/// - no non-system, single-field index on the column has details whose type URL ends
///   with `BTreeIndexDetails`;
/// - any fragment has no `physical_rows` recorded;
/// - the index has a fragment bitmap and at least one fragment id is absent from it.
///
/// Otherwise returns `IndexCoverage::Indexed`.
///
/// # Errors
/// Fails only when the dataset's indices cannot be loaded.
pub async fn key_column_index_coverage(ds: &Dataset, column: &str) -> Result<IndexCoverage> {
let Some(field_id) = ds.schema().field(column).map(|field| field.id) else {
return Ok(IndexCoverage::Degraded {
reason: format!("column '{}' not in schema", column),
});
};
let indices = ds
.load_indices()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
// First user index that targets exactly this column and is a BTREE.
let btree = indices
.iter()
.filter(|index| !is_system_index(index))
.filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
.find(|index| {
index
.index_details
.as_ref()
.map(|details| details.type_url.ends_with("BTreeIndexDetails"))
.unwrap_or(false)
});
let Some(btree) = btree else {
return Ok(IndexCoverage::Degraded {
reason: format!("no BTREE index on '{}'", column),
});
};
if ds.fragments().iter().any(|f| f.physical_rows.is_none()) {
return Ok(IndexCoverage::Degraded {
reason: "a fragment is missing physical_rows".to_string(),
});
}
// NOTE(review): an index without a fragment bitmap skips this check entirely and is
// reported as Indexed — confirm that is intended.
if let Some(bitmap) = btree.fragment_bitmap.as_ref() {
let uncovered = ds
.fragments()
.iter()
// Fragment ids are u64 here while the bitmap is queried with u32.
.filter(|f| !bitmap.contains(f.id as u32))
.count();
if uncovered > 0 {
return Ok(IndexCoverage::Degraded {
reason: format!(
"{} fragment(s) not covered by the index on '{}'",
uncovered, column
),
});
}
}
Ok(IndexCoverage::Indexed)
}
/// Reports whether any non-system index with a fragment bitmap misses a current fragment.
///
/// Indices without a fragment bitmap are ignored.
pub async fn has_unindexed_fragments(ds: &Dataset) -> Result<bool> {
    let indices = ds
        .load_indices()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let fragment_ids: Vec<u32> = ds.fragments().iter().map(|frag| frag.id as u32).collect();
    let any_gap = indices
        .iter()
        .filter(|index| !is_system_index(index))
        .filter_map(|index| index.fragment_bitmap.as_ref())
        .any(|bitmap| fragment_ids.iter().any(|id| !bitmap.contains(*id)));
    Ok(any_gap)
}
/// Counts the rows of `ds`, optionally restricted by a filter expression.
pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
    match ds.count_rows(filter).await {
        Ok(count) => Ok(count as usize),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Returns the version number of `ds`.
pub fn dataset_version(&self, ds: &Dataset) -> u64 {
    let current = ds.version();
    current.version
}
/// Captures the version, unfiltered row count, and version metadata of `ds`.
pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
    let version = self.dataset_version(ds);
    let row_count = self.count_rows(ds, None).await? as u64;
    let version_metadata = self.dataset_version_metadata(dataset_uri, ds)?;
    Ok(TableState {
        version,
        row_count,
        version_metadata,
    })
}
/// Appends `batch` to `ds` and returns the resulting table state.
///
/// An empty batch performs no write; the current state of `ds` is returned instead.
/// The append runs with auto-cleanup disabled and external blobs outside the dataset
/// bases allowed.
pub(crate) async fn append_batch(
    &self,
    dataset_uri: &str,
    ds: &mut Dataset,
    batch: RecordBatch,
) -> Result<TableState> {
    if batch.num_rows() > 0 {
        let params = WriteParams {
            mode: WriteMode::Append,
            allow_external_blob_outside_bases: true,
            auto_cleanup: None,
            skip_auto_cleanup: true,
            ..Default::default()
        };
        let schema = batch.schema();
        let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
        ds.append(reader, Some(params))
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
    }
    self.table_state(dataset_uri, ds).await
}
pub async fn append_or_create_batch(
dataset_uri: &str,
dataset: Option<Dataset>,
batch: RecordBatch,
) -> Result<Dataset> {
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
match dataset {
Some(mut ds) => {
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
auto_cleanup: None,
skip_auto_cleanup: true,
..Default::default()
};
ds.append(reader, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(ds)
}
None => {
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
auto_cleanup: None,
skip_auto_cleanup: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}
}
/// Deletes the rows of `ds` matching `filter` and describes the resulting dataset.
///
/// Version, row count, and metadata are all read from the dataset returned by the delete.
pub(crate) async fn delete_where(
    &self,
    dataset_uri: &str,
    ds: &mut Dataset,
    filter: &str,
) -> Result<DeleteState> {
    let outcome = ds
        .delete(filter)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let version = outcome.new_dataset.version().version;
    let row_count = self.count_rows(&outcome.new_dataset, None).await? as u64;
    let deleted_rows = outcome.num_deleted_rows as usize;
    let version_metadata = self.dataset_version_metadata(dataset_uri, &outcome.new_dataset)?;
    Ok(DeleteState {
        version,
        row_count,
        deleted_rows,
        version_metadata,
    })
}
/// Writes `batch` as uncommitted append data and returns it as a [`StagedWrite`].
///
/// Nothing is committed to `ds`. The returned fragments are copies of those in the
/// transaction with two adjustments:
/// - fragments whose id is still 0 get ids starting at
///   `max_fragment_id + 1 + <fragments staged by prior_stages>` (presumably mirroring the
///   ids Lance assigns at commit time — TODO confirm);
/// - when the dataset uses stable row ids, fragments without row-id metadata get
///   consecutive row ids starting at `next_row_id + <rows staged by prior_stages>`.
///
/// # Errors
/// Fails on an empty batch, on any Lance write failure, when the transaction is neither
/// an append nor an overwrite, or when a fragment lacks `physical_rows`.
pub async fn stage_append(
    &self,
    ds: &Dataset,
    batch: RecordBatch,
    prior_stages: &[StagedWrite],
) -> Result<StagedWrite> {
    if batch.num_rows() == 0 {
        return Err(OmniError::manifest_internal(
            "stage_append called with empty batch".to_string(),
        ));
    }
    let params = WriteParams {
        mode: WriteMode::Append,
        allow_external_blob_outside_bases: true,
        auto_cleanup: None,
        skip_auto_cleanup: true,
        ..Default::default()
    };
    let transaction = InsertBuilder::new(Arc::new(ds.clone()))
        .with_params(&params)
        .execute_uncommitted(vec![batch])
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let mut new_fragments = match &transaction.operation {
        Operation::Append { fragments } => fragments.clone(),
        Operation::Overwrite { fragments, .. } => fragments.clone(),
        other => {
            return Err(OmniError::manifest_internal(format!(
                "stage_append: unexpected Lance operation {:?}",
                std::mem::discriminant(other)
            )));
        }
    };
    // Earlier stages in the same batch of work have already claimed ids after the
    // committed maximum, so skip past them.
    let next_id_base = ds.manifest.max_fragment_id.unwrap_or(0) as u64
        + 1
        + prior_stages_fragment_count(prior_stages);
    assign_fragment_ids(&mut new_fragments, next_id_base);
    if ds.manifest.uses_stable_row_ids() {
        let prior_rows = prior_stages_row_count(prior_stages)?;
        let start_row_id = ds.manifest.next_row_id + prior_rows;
        assign_row_id_meta(&mut new_fragments, start_row_id)?;
    }
    Ok(StagedWrite {
        transaction,
        new_fragments,
        removed_fragment_ids: Vec::new(),
    })
}
/// Runs a merge-insert of `batch` into `ds` without committing it.
///
/// Rows are matched on `key_columns`; `when_matched` and `when_not_matched` are passed
/// through to the Lance merge-insert builder. The returned `StagedWrite` holds the
/// uncommitted transaction, the fragments it produces, and the fragment ids it removes.
///
/// # Errors
/// Fails on an empty batch, when `check_batch_unique_by_keys` rejects the batch, on any
/// Lance failure, or when the transaction is neither an update nor an append.
pub async fn stage_merge_insert(
&self,
ds: Dataset,
batch: RecordBatch,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<StagedWrite> {
if batch.num_rows() == 0 {
return Err(OmniError::manifest_internal(
"stage_merge_insert called with empty batch".to_string(),
));
}
// Defined elsewhere in this file; presumably rejects duplicate keys within `batch`
// — TODO confirm.
check_batch_unique_by_keys(&batch, &key_columns, "stage_merge_insert")?;
let ds = Arc::new(ds);
let mut builder = MergeInsertBuilder::try_new(ds, key_columns)
.map_err(|e| OmniError::Lance(e.to_string()))?;
builder.when_matched(when_matched);
builder.when_not_matched(when_not_matched);
builder.source_dedupe_behavior(SourceDedupeBehavior::FirstSeen);
let job = builder
.try_build()
.map_err(|e| OmniError::Lance(e.to_string()))?;
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let stream = lance_datafusion::utils::reader_to_stream(Box::new(reader));
let uncommitted = job
.execute_uncommitted(stream)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let (new_fragments, removed_fragment_ids) = match &uncommitted.transaction.operation {
Operation::Update {
new_fragments,
updated_fragments,
removed_fragment_ids,
..
} => {
// Report updated fragments first, followed by the brand-new ones.
let mut all = updated_fragments.clone();
all.extend(new_fragments.iter().cloned());
(all, removed_fragment_ids.clone())
}
Operation::Append { fragments } => (fragments.clone(), Vec::new()),
other => {
return Err(OmniError::manifest_internal(format!(
"stage_merge_insert: unexpected Lance operation {:?}",
std::mem::discriminant(other)
)));
}
};
Ok(StagedWrite {
transaction: uncommitted.transaction,
new_fragments,
removed_fragment_ids,
})
}
/// Commits a previously staged `transaction` against `ds`, with auto-cleanup skipped.
///
/// Returns the dataset produced by the commit.
pub async fn commit_staged(
    &self,
    ds: Arc<Dataset>,
    transaction: Transaction,
) -> Result<Dataset> {
    let builder = CommitBuilder::new(ds).with_skip_auto_cleanup(true);
    match builder.execute(transaction).await {
        Ok(committed) => Ok(committed),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Stages a full overwrite of `ds` with `batch`, without committing it.
///
/// An empty batch writes no data: an overwrite transaction with no fragments and the
/// batch's schema is built directly. Otherwise the batch is written uncommitted in
/// overwrite mode.
///
/// In the returned [`StagedWrite`], fragments whose id is still 0 get ids counted from 1,
/// and (when the dataset uses stable row ids) fragments without row-id metadata get row
/// ids counted from 0 — presumably matching what Lance assigns when the overwrite
/// commits; TODO confirm. Every fragment currently in the manifest is listed as removed.
///
/// # Errors
/// Fails on schema conversion or Lance write failures, when the transaction is not an
/// overwrite, or when a fragment lacks `physical_rows`.
pub async fn stage_overwrite(&self, ds: &Dataset, batch: RecordBatch) -> Result<StagedWrite> {
    let (transaction, mut new_fragments) = if batch.num_rows() == 0 {
        let schema = LanceSchema::try_from(batch.schema().as_ref())
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let transaction = TransactionBuilder::new(
            ds.manifest.version,
            Operation::Overwrite {
                fragments: Vec::new(),
                schema,
                config_upsert_values: None,
                initial_bases: None,
            },
        )
        .build();
        (transaction, Vec::new())
    } else {
        let params = WriteParams {
            mode: WriteMode::Overwrite,
            enable_stable_row_ids: true,
            allow_external_blob_outside_bases: true,
            auto_cleanup: None,
            skip_auto_cleanup: true,
            ..Default::default()
        };
        let transaction = InsertBuilder::new(Arc::new(ds.clone()))
            .with_params(&params)
            .execute_uncommitted(vec![batch])
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let new_fragments = match &transaction.operation {
            Operation::Overwrite { fragments, .. } => fragments.clone(),
            other => {
                return Err(OmniError::manifest_internal(format!(
                    "stage_overwrite: unexpected Lance operation {:?}",
                    std::mem::discriminant(other)
                )));
            }
        };
        (transaction, new_fragments)
    };
    assign_fragment_ids(&mut new_fragments, 1);
    if ds.manifest.uses_stable_row_ids() {
        assign_row_id_meta(&mut new_fragments, 0)?;
    }
    // An overwrite replaces everything that is currently committed.
    let removed_fragment_ids: Vec<u64> = ds.manifest.fragments.iter().map(|f| f.id).collect();
    Ok(StagedWrite {
        transaction,
        new_fragments,
        removed_fragment_ids,
    })
}
/// Builds a BTREE index over `columns` without committing it.
///
/// The index is built on a clone of `ds`, so `ds` itself is not modified. The returned
/// [`StagedWrite`] carries a `CreateIndex` transaction; existing indices with the same
/// name as the new one are listed as removed. No fragments are added or removed.
///
/// # Errors
/// Lance failures while building the index or loading existing indices are reported as
/// `OmniError::Lance`.
pub async fn stage_create_btree_index(
    &self,
    ds: &Dataset,
    columns: &[&str],
) -> Result<StagedWrite> {
    let params = ScalarIndexParams::default();
    let mut ds_clone = ds.clone();
    let new_idx = ds_clone
        .create_index_builder(columns, IndexType::BTree, &params)
        .replace(true)
        .execute_uncommitted()
        .await
        .map_err(|e| OmniError::Lance(format!("stage_create_btree_index: {}", e)))?;
    // Same-named indices are superseded by the new one.
    let removed_indices: Vec<IndexMetadata> = ds
        .load_indices()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?
        .iter()
        .filter(|idx| idx.name == new_idx.name)
        .cloned()
        .collect();
    let transaction = TransactionBuilder::new(
        new_idx.dataset_version,
        Operation::CreateIndex {
            new_indices: vec![new_idx],
            removed_indices,
        },
    )
    .build();
    Ok(StagedWrite {
        transaction,
        new_fragments: Vec::new(),
        removed_fragment_ids: Vec::new(),
    })
}
/// Builds an inverted index over `column` without committing it.
///
/// The index is built on a clone of `ds`, so `ds` itself is not modified. The returned
/// [`StagedWrite`] carries a `CreateIndex` transaction; existing indices with the same
/// name as the new one are listed as removed. No fragments are added or removed.
///
/// # Errors
/// Lance failures while building the index or loading existing indices are reported as
/// `OmniError::Lance`.
pub async fn stage_create_inverted_index(
    &self,
    ds: &Dataset,
    column: &str,
) -> Result<StagedWrite> {
    let params = InvertedIndexParams::default();
    let mut ds_clone = ds.clone();
    let new_idx = ds_clone
        .create_index_builder(&[column], IndexType::Inverted, &params)
        .replace(true)
        .execute_uncommitted()
        .await
        .map_err(|e| OmniError::Lance(format!("stage_create_inverted_index: {}", e)))?;
    // Same-named indices are superseded by the new one.
    let removed_indices: Vec<IndexMetadata> = ds
        .load_indices()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?
        .iter()
        .filter(|idx| idx.name == new_idx.name)
        .cloned()
        .collect();
    let transaction = TransactionBuilder::new(
        new_idx.dataset_version,
        Operation::CreateIndex {
            new_indices: vec![new_idx],
            removed_indices,
        },
    )
    .build();
    Ok(StagedWrite {
        transaction,
        new_fragments: Vec::new(),
        removed_fragment_ids: Vec::new(),
    })
}
/// Scans `ds` as it would look with the `staged` writes applied.
///
/// With no staged writes this is a plain scan. Otherwise the scanner is restricted to the
/// fragment list produced by `combine_committed_with_staged`.
pub async fn scan_with_staged(
    &self,
    ds: &Dataset,
    staged: &[StagedWrite],
    projection: Option<&[&str]>,
    filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
    if staged.is_empty() {
        return self.scan(ds, projection, filter, None).await;
    }
    let mut scanner = ds.scan();
    if let Some(cols) = projection {
        scanner
            .project(cols)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
    }
    if let Some(predicate) = filter {
        scanner
            .filter(predicate)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
    }
    let visible_fragments = combine_committed_with_staged(ds, staged);
    scanner.with_fragments(visible_fragments);
    let stream = match scanner.try_into_stream().await {
        Ok(stream) => stream,
        Err(e) => return Err(OmniError::Lance(e.to_string())),
    };
    stream
        .try_collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Scans committed data together with in-memory `pending_batches`.
///
/// Committed rows come first in the result, followed by the pending rows; both are
/// subject to the same `projection` and `filter`. When `key_column` is given, committed
/// rows whose key equals a non-null key found in the pending batches are dropped, so the
/// pending version of a row replaces the committed one.
///
/// # Errors
/// Fails when `key_column` is given but missing from an explicit `projection`, when the
/// key column is absent or not a string column, or when either scan fails.
pub async fn scan_with_pending(
&self,
committed_ds: &Dataset,
pending_batches: &[RecordBatch],
pending_schema: Option<SchemaRef>,
projection: Option<&[&str]>,
filter: Option<&str>,
key_column: Option<&str>,
) -> Result<Vec<RecordBatch>> {
// The key column must survive projection, otherwise shadowing cannot be evaluated.
if let (Some(key_col), Some(cols)) = (key_column, projection) {
if !cols.iter().any(|c| *c == key_col) {
return Err(OmniError::Lance(format!(
"scan_with_pending: key_column '{}' must appear in projection \
when merge-shadow semantics are requested (got projection = {:?})",
key_col, cols
)));
}
}
let committed = self.scan(committed_ds, projection, filter, None).await?;
if pending_batches.is_empty() {
return Ok(committed);
}
let committed = match key_column {
Some(key_col) => {
// Keys are collected from the unfiltered pending batches.
let pending_keys = collect_string_column_values(pending_batches, key_col)?;
if pending_keys.is_empty() {
committed
} else {
filter_out_rows_where_string_in(committed, key_col, &pending_keys)?
}
}
None => committed,
};
let pending =
scan_pending_batches(pending_batches, pending_schema, projection, filter).await?;
let mut out = committed;
out.extend(pending);
Ok(out)
}
/// Counts rows of `ds` as it would look with the `staged` writes applied.
///
/// With no staged writes this is a plain `count_rows`. Otherwise the count runs over the
/// fragment list produced by `combine_committed_with_staged`.
pub async fn count_rows_with_staged(
    &self,
    ds: &Dataset,
    staged: &[StagedWrite],
    filter: Option<String>,
) -> Result<usize> {
    if staged.is_empty() {
        return self.count_rows(ds, filter).await;
    }
    let mut scanner = ds.scan();
    if let Some(predicate) = filter.as_deref() {
        scanner
            .filter(predicate)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
    }
    let visible_fragments = combine_committed_with_staged(ds, staged);
    scanner.with_fragments(visible_fragments);
    match scanner.count_rows().await {
        Ok(total) => Ok(total as usize),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Lists the non-system indices defined on exactly the single field `column`.
///
/// # Errors
/// Returns a manifest-internal error when `column` is not in the schema, and
/// `OmniError::Lance` when the indices cannot be loaded.
async fn user_indices_for_column(
    &self,
    ds: &Dataset,
    column: &str,
) -> Result<Vec<IndexMetadata>> {
    let Some(field_id) = ds.schema().field(column).map(|field| field.id) else {
        return Err(OmniError::manifest_internal(format!(
            "dataset is missing expected index column '{}'",
            column
        )));
    };
    let indices = ds
        .load_indices()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let mut matching = Vec::new();
    for index in indices.iter() {
        if is_system_index(index) {
            continue;
        }
        if index.fields.len() == 1 && index.fields[0] == field_id {
            matching.push(index.clone());
        }
    }
    Ok(matching)
}
/// Reports whether `column` has a user index whose details type URL ends with
/// `BTreeIndexDetails`.
pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
    let indices = self.user_indices_for_column(ds, column).await?;
    let found = indices
        .iter()
        .filter_map(|index| index.index_details.as_ref())
        .any(|details| details.type_url.ends_with("BTreeIndexDetails"));
    Ok(found)
}
/// Reports whether `column` has a user index whose details support full-text search.
pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
    let indices = self.user_indices_for_column(ds, column).await?;
    let found = indices.iter().any(|index| match index.index_details.as_ref() {
        Some(details) => IndexDetails(details.clone()).supports_fts(),
        None => false,
    });
    Ok(found)
}
/// Reports whether `column` has a user index whose details identify a vector index.
pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
    let indices = self.user_indices_for_column(ds, column).await?;
    let found = indices.iter().any(|index| match index.index_details.as_ref() {
        Some(details) => IndexDetails(details.clone()).is_vector(),
        None => false,
    });
    Ok(found)
}
/// Creates (or replaces) a vector index on `column` directly on `ds`.
///
/// The index uses `VectorIndexParams::ivf_flat(1, MetricType::L2)` — presumably a single
/// IVF partition with L2 distance; TODO confirm the meaning of the first argument.
///
/// # Errors
/// Lance failures are reported as `OmniError::Lance`.
pub(crate) async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
    let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
    ds.create_index_builder(&[column], IndexType::Vector, &params)
        .replace(true)
        .await
        .map(|_| ())
        .map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
let batch = RecordBatch::new_empty(schema.clone());
Self::write_dataset(dataset_uri, batch).await
}
/// Returns the row id of the first scanned row matching `filter`, if any.
///
/// The scan projects the `id` column and requests row ids. The answer is the first value
/// of the `_rowid` column in the first batch that has a non-empty `UInt64` `_rowid`.
pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
    let stream = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true).await?;
    let batches: Vec<RecordBatch> = stream
        .try_collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    for batch in &batches {
        let Some(row_ids) = batch
            .column_by_name("_rowid")
            .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        else {
            continue;
        };
        if row_ids.len() > 0 {
            return Ok(Some(row_ids.value(0)));
        }
    }
    Ok(None)
}
/// Creates a new dataset at `dataset_uri` containing `batch`.
///
/// The dataset is created with stable row ids enabled, file format version 2.2,
/// auto-cleanup disabled, and external blobs outside the dataset bases allowed.
pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
    let params = WriteParams {
        mode: WriteMode::Create,
        enable_stable_row_ids: true,
        data_storage_version: Some(LanceFileVersion::V2_2),
        allow_external_blob_outside_bases: true,
        auto_cleanup: None,
        skip_auto_cleanup: true,
        ..Default::default()
    };
    let schema = batch.schema();
    let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
    match Dataset::write(reader, dataset_uri, Some(params)).await {
        Ok(created) => Ok(created),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
}
/// Total number of new fragments across all `prior_stages`.
fn prior_stages_fragment_count(prior_stages: &[StagedWrite]) -> u64 {
    let mut count: u64 = 0;
    for stage in prior_stages {
        count += stage.new_fragments.len() as u64;
    }
    count
}
/// Gives every fragment whose id is 0 the id `start_id + <its position in the slice>`.
///
/// Fragments with a non-zero id are left alone, but still occupy a position, so the ids
/// handed out are not necessarily contiguous.
fn assign_fragment_ids(fragments: &mut [Fragment], start_id: u64) {
    fragments
        .iter_mut()
        .enumerate()
        .filter(|(_, fragment)| fragment.id == 0)
        .for_each(|(offset, fragment)| fragment.id = start_id + offset as u64);
}
/// Total `physical_rows` across every new fragment of every stage in `prior_stages`.
///
/// # Errors
/// Returns a manifest-internal error at the first fragment without `physical_rows`.
fn prior_stages_row_count(prior_stages: &[StagedWrite]) -> Result<u64> {
    prior_stages
        .iter()
        .flat_map(|stage| stage.new_fragments.iter())
        .try_fold(0u64, |running, fragment| -> Result<u64> {
            let rows = fragment.physical_rows.ok_or_else(|| {
                OmniError::manifest_internal(
                    "prior_stages_row_count: fragment is missing physical_rows".to_string(),
                )
            })? as u64;
            Ok(running + rows)
        })
}
/// Attaches inline row-id metadata to every fragment that has none.
///
/// Each such fragment receives the consecutive range of `physical_rows` ids beginning at
/// the running row id, which starts at `start_row_id`. Fragments that already carry
/// row-id metadata are skipped and do not advance the running row id.
///
/// # Errors
/// Returns a manifest-internal error when a fragment needing ids has no `physical_rows`.
fn assign_row_id_meta(fragments: &mut [Fragment], start_row_id: u64) -> Result<()> {
    let mut cursor = start_row_id;
    for fragment in fragments.iter_mut().filter(|f| f.row_id_meta.is_none()) {
        let Some(rows) = fragment.physical_rows else {
            return Err(OmniError::manifest_internal(
                "stage_append: fragment is missing physical_rows".to_string(),
            ));
        };
        let range_end = cursor + rows as u64;
        let encoded = write_row_ids(&RowIdSequence::from(cursor..range_end));
        fragment.row_id_meta = Some(RowIdMeta::Inline(encoded));
        cursor = range_end;
    }
    Ok(())
}
/// Gathers the distinct non-null values of string column `column` across `batches`.
///
/// # Errors
/// Fails when a batch lacks `column` or the column is not a `StringArray`.
fn collect_string_column_values(
    batches: &[RecordBatch],
    column: &str,
) -> Result<std::collections::HashSet<String>> {
    use arrow_array::{Array, StringArray};
    let mut distinct = std::collections::HashSet::new();
    for batch in batches {
        let values = batch.column_by_name(column).ok_or_else(|| {
            OmniError::Lance(format!(
                "scan_with_pending: pending batch missing key column '{}'",
                column
            ))
        })?;
        let Some(strings) = values.as_any().downcast_ref::<StringArray>() else {
            return Err(OmniError::Lance(format!(
                "scan_with_pending: key column '{}' is not Utf8",
                column
            )));
        };
        distinct.extend(
            (0..strings.len())
                .filter(|&row| strings.is_valid(row))
                .map(|row| strings.value(row).to_string()),
        );
    }
    Ok(distinct)
}
/// Drops from each batch the rows whose `column` value is contained in `excluded`.
///
/// Rows with a null key are kept. Batches with zero rows are passed through untouched,
/// without checking for the column.
///
/// # Errors
/// Fails when a non-empty batch lacks `column`, the column is not a `StringArray`, or
/// the Arrow filter kernel fails.
fn filter_out_rows_where_string_in(
    batches: Vec<RecordBatch>,
    column: &str,
    excluded: &std::collections::HashSet<String>,
) -> Result<Vec<RecordBatch>> {
    use arrow_array::{Array, BooleanArray, StringArray};
    let mut kept = Vec::with_capacity(batches.len());
    for batch in batches {
        let surviving = if batch.num_rows() == 0 {
            batch
        } else {
            let Some(values) = batch.column_by_name(column) else {
                return Err(OmniError::manifest_internal(format!(
                    "scan_with_pending: committed batch missing key column '{}' \
                    (the up-front projection check should have rejected this)",
                    column
                )));
            };
            let Some(keys) = values.as_any().downcast_ref::<StringArray>() else {
                return Err(OmniError::Lance(format!(
                    "scan_with_pending: committed column '{}' is not Utf8",
                    column
                )));
            };
            // A row is dropped only when its key is non-null AND listed in `excluded`.
            let keep_mask: BooleanArray = (0..keys.len())
                .map(|row| Some(!(keys.is_valid(row) && excluded.contains(keys.value(row)))))
                .collect();
            arrow_select::filter::filter_record_batch(&batch, &keep_mask)
                .map_err(|e| OmniError::Lance(e.to_string()))?
        };
        kept.push(surviving);
    }
    Ok(kept)
}
/// Runs a projection and optional filter over in-memory `pending_batches` by
/// registering them as a DataFusion `MemTable` named `pending` and issuing a
/// `SELECT` against it.
///
/// * `pending_schema` — schema for the in-memory table; when `None`, the schema
///   of the first pending batch is used.
/// * `projection` — column names to select; each is double-quoted (embedded
///   quotes are doubled). `None` selects `*`.
/// * `filter` — SQL predicate placed after `WHERE`. NOTE: it is spliced into the
///   statement verbatim, so callers must only pass trusted predicate text.
///
/// When neither a schema nor any batch is supplied there is nothing to scan and
/// an empty result is returned (previously this case panicked on
/// `pending_batches[0]`).
///
/// # Errors
/// Any DataFusion failure (table creation, registration, planning, execution)
/// is surfaced as `OmniError::Lance` carrying the error's display text.
async fn scan_pending_batches(
    pending_batches: &[RecordBatch],
    pending_schema: Option<SchemaRef>,
    projection: Option<&[&str]>,
    filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
    let schema = match pending_schema {
        Some(schema) => schema,
        None => match pending_batches.first() {
            Some(batch) => batch.schema(),
            // No schema and no rows: avoid the out-of-bounds index panic.
            None => return Ok(Vec::new()),
        },
    };
    let ctx = datafusion::execution::context::SessionContext::new();
    let mem = datafusion::datasource::MemTable::try_new(schema, vec![pending_batches.to_vec()])
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    ctx.register_table("pending", Arc::new(mem))
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    let proj = projection
        .map(|cols| {
            cols.iter()
                // Quote identifiers so mixed-case / special-character names resolve.
                .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
                .collect::<Vec<_>>()
                .join(", ")
        })
        .unwrap_or_else(|| "*".to_string());
    let where_clause = filter.map(|f| format!("WHERE {f}")).unwrap_or_default();
    let sql = format!("SELECT {proj} FROM pending {where_clause}");
    let df = ctx
        .sql(&sql)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    df.collect()
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Builds the fragment list that results from overlaying `staged` writes on
/// the dataset's current manifest.
///
/// Manifest fragments whose id is listed in any staged write's
/// `removed_fragment_ids` are dropped; the remaining manifest fragments come
/// first (in manifest order), followed by each staged write's `new_fragments`
/// in stage order.
fn combine_committed_with_staged(ds: &Dataset, staged: &[StagedWrite]) -> Vec<Fragment> {
    let mut dropped_ids: std::collections::HashSet<u64> = std::collections::HashSet::new();
    for write in staged {
        dropped_ids.extend(write.removed_fragment_ids.iter().copied());
    }

    let surviving = ds
        .manifest
        .fragments
        .iter()
        .filter(|frag| !dropped_ids.contains(&frag.id))
        .cloned();
    let added = staged
        .iter()
        .flat_map(|write| write.new_fragments.iter().cloned());

    surviving.chain(added).collect()
}
/// Verifies that no two rows of `batch` share the same non-null value in the
/// key column.
///
/// Only a single key column is supported. Null key slots are skipped and never
/// count as duplicates. `context` is prefixed to every error message.
///
/// # Errors
/// - manifest-internal error if `key_columns` does not contain exactly one
///   name, if the batch has no such column, or if the column cannot be
///   downcast to a `StringArray`;
/// - manifest error naming the first repeated key when a duplicate is found.
fn check_batch_unique_by_keys(
    batch: &RecordBatch,
    key_columns: &[String],
    context: &'static str,
) -> Result<()> {
    let [key_col_name] = key_columns else {
        return Err(OmniError::manifest_internal(format!(
            "{}: check_batch_unique_by_keys currently supports single-column keys only, got {:?}",
            context, key_columns
        )));
    };
    let Some(column) = batch.column_by_name(key_col_name) else {
        return Err(OmniError::manifest_internal(format!(
            "{}: source batch missing key column '{}'",
            context, key_col_name
        )));
    };
    let Some(keys) = column.as_any().downcast_ref::<StringArray>() else {
        return Err(OmniError::manifest_internal(format!(
            "{}: key column '{}' is not a StringArray (got {:?})",
            context,
            key_col_name,
            column.data_type()
        )));
    };

    let mut seen: std::collections::HashSet<&str> =
        std::collections::HashSet::with_capacity(batch.num_rows());
    // `insert` returns false for a value already present; `find` stops at the
    // first such repeat. `flatten` skips null slots.
    let first_repeat = keys.iter().flatten().find(|key| !seen.insert(*key));
    match first_repeat {
        Some(dup) => Err(OmniError::manifest(format!(
            "{}: duplicate source row for key '{}' (column '{}'); \
             callers must hand in a batch unique by `key_columns` \
             — see MR-957",
            context, dup, key_col_name
        ))),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    /// Builds a one-column batch whose non-nullable Utf8 `id` column holds `ids`.
    fn batch_with_ids(ids: &[&str]) -> RecordBatch {
        let id_field = Field::new("id", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![id_field]));
        let id_column: ArrayRef = Arc::new(StringArray::from(ids.to_vec()));
        RecordBatch::try_new(schema, vec![id_column]).unwrap()
    }

    /// Distinct keys must pass the uniqueness check.
    #[test]
    fn check_batch_unique_by_keys_passes_when_all_unique() {
        let keys = ["id".to_string()];
        let batch = batch_with_ids(&["a", "b", "c"]);
        check_batch_unique_by_keys(&batch, &keys, "test").unwrap();
    }

    /// A repeated key must be rejected with a message naming the key and MR-957.
    #[test]
    fn check_batch_unique_by_keys_errors_on_duplicate_id() {
        let keys = ["id".to_string()];
        let batch = batch_with_ids(&["a", "b", "a"]);
        let msg = check_batch_unique_by_keys(&batch, &keys, "test")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("duplicate source row for key 'a'"),
            "unexpected error: {msg}"
        );
        assert!(
            msg.contains("MR-957"),
            "error should reference MR-957: {msg}"
        );
    }

    /// More than one key column is unsupported and must be reported as such.
    #[test]
    fn check_batch_unique_by_keys_rejects_multi_column_keys() {
        let keys = ["id".to_string(), "other".to_string()];
        let batch = batch_with_ids(&["a"]);
        let err = check_batch_unique_by_keys(&batch, &keys, "test").unwrap_err();
        assert!(err.to_string().contains("single-column keys only"));
    }
}