use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::storage_layer::{SnapshotHandle, StagedHandle};
use arrow_array::{Array, RecordBatch, StringArray, UInt32Array};
use arrow_schema::SchemaRef;
use futures::stream::StreamExt;
use crate::db::manifest::{
RecoverySidecarHandle, SidecarKind, SidecarTablePin, new_sidecar, write_sidecar,
};
use crate::db::{MutationOpKind, SubTableUpdate};
use crate::error::{OmniError, Result};
/// How the batches accumulated for a table are written when the mutation is staged.
///
/// `Append` is staged through `stage_append`. `Merge` is staged through
/// `stage_merge_insert` keyed on the `id` column (update matched rows, insert the
/// rest). `Overwrite` is staged through `stage_overwrite`.
///
/// An `Append` entry is promoted to `Merge` when a merge batch arrives for the
/// same table. `Overwrite` may never be mixed with the other two modes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PendingMode {
    /// Rows are appended as-is.
    Append,
    /// Rows are upserted keyed on `id`.
    Merge,
    /// Rows are staged as an overwrite of the table.
    Overwrite,
}
/// Batches queued for a single table, with the schema and write mode they will
/// be staged with.
#[derive(Debug)]
pub(crate) struct PendingTable {
    /// Schema given when the entry was created.
    pub(crate) schema: SchemaRef,
    /// Write mode for the whole entry (an `Append` entry may later become `Merge`).
    pub(crate) mode: PendingMode,
    /// Queued batches, in the order they were added.
    pub(crate) batches: Vec<RecordBatch>,
}

impl PendingTable {
    /// Creates an entry that has no batches yet.
    fn new(schema: SchemaRef, mode: PendingMode) -> Self {
        Self {
            batches: Vec::new(),
            mode,
            schema,
        }
    }

    /// Total number of rows across all queued batches.
    fn total_rows(&self) -> usize {
        self.batches
            .iter()
            .fold(0, |rows, batch| rows + batch.num_rows())
    }
}
/// Location of a touched table: its full dataset path and the optional
/// table-level branch it is opened on.
#[derive(Clone, Debug)]
pub(crate) struct StagedTablePath {
    /// Full path handed to `reopen_for_mutation` and `table_state`.
    pub(crate) full_path: String,
    /// Branch handed to `reopen_for_mutation`; also part of the write-queue key.
    pub(crate) table_branch: Option<String>,
}
/// In-memory, per-table state accumulated while a mutation runs.
///
/// Every map is keyed by table key. Storage is only touched once
/// [`MutationStaging::stage_all`] (or `stage_all_with_concurrency`) consumes this
/// value.
#[derive(Default)]
pub(crate) struct MutationStaging {
    /// Expected table version from the first `ensure_path` call for each table.
    pub(crate) expected_versions: HashMap<String, u64>,
    /// Dataset location from the first `ensure_path` call for each table.
    pub(crate) paths: HashMap<String, StagedTablePath>,
    /// Batches waiting to be appended, merged or overwritten.
    pub(crate) pending: HashMap<String, PendingTable>,
    /// Delete predicates; several for one table are OR-ed together at staging time.
    pub(crate) delete_predicates: HashMap<String, Vec<String>>,
    /// Ids of deleted rows. In this file they are read only by `to_changeset`.
    pub(crate) deleted_ids: HashMap<String, Vec<String>>,
    /// Strictest op kind seen per table (strict-version-check kinds take precedence).
    pub(crate) op_kinds: HashMap<String, MutationOpKind>,
}
impl MutationStaging {
    /// Records that `table_key` is touched by this mutation.
    ///
    /// The first call for a table fixes its `full_path`, `table_branch` and
    /// `expected_version`; later calls do not overwrite them. The recorded op
    /// kind is replaced only when `op_kind` requires the strict pre-stage version
    /// check and the recorded one does not, so the strictest requirement wins.
    pub(crate) fn ensure_path(
        &mut self,
        table_key: &str,
        full_path: String,
        table_branch: Option<String>,
        expected_version: u64,
        op_kind: MutationOpKind,
    ) {
        let key = table_key.to_string();
        self.paths
            .entry(key.clone())
            .or_insert_with(|| StagedTablePath {
                full_path,
                table_branch,
            });
        self.expected_versions
            .entry(key.clone())
            .or_insert(expected_version);
        match self.op_kinds.get_mut(&key) {
            None => {
                self.op_kinds.insert(key, op_kind);
            }
            Some(recorded) => {
                let escalate = op_kind.strict_pre_stage_version_check()
                    && !recorded.strict_pre_stage_version_check();
                if escalate {
                    *recorded = op_kind;
                }
            }
        }
    }

    /// Queues `batch` for `table_key`.
    ///
    /// Zero-row batches are ignored unless `mode` is `Overwrite` (an empty
    /// overwrite is still recorded). The first call for a table creates its
    /// entry with `schema` and `mode`. Later calls:
    /// * may not mix `Overwrite` with another mode (internal error);
    /// * must carry a batch schema that is compatible with the entry's schema
    ///   (user-facing error explaining how to split the query);
    /// * promote an `Append` entry to `Merge` when `mode` is `Merge`.
    pub(crate) fn append_batch(
        &mut self,
        table_key: &str,
        schema: SchemaRef,
        mode: PendingMode,
        batch: RecordBatch,
    ) -> Result<()> {
        let is_overwrite = mode == PendingMode::Overwrite;
        if !is_overwrite && batch.num_rows() == 0 {
            return Ok(());
        }

        if let Some(prior) = self.pending.get(table_key) {
            let mixes_overwrite = prior.mode != mode
                && (prior.mode == PendingMode::Overwrite || is_overwrite);
            if mixes_overwrite {
                return Err(OmniError::manifest_internal(format!(
                    "table '{}' cannot mix overwrite staging with append/merge staging",
                    table_key
                )));
            }

            let incoming = batch.schema();
            if !schemas_compatible(&prior.schema, &incoming) {
                return Err(OmniError::manifest(format!(
                    "table '{}' accumulated mutation batches with mismatched schemas: \
                     prior batches have {} columns, this batch has {}. \
                     This typically happens on a blob-bearing table when one \
                     op uses the full schema (e.g. an `insert`) and another \
                     omits unassigned blob columns (e.g. an `update` that \
                     doesn't set every blob property). Split the mutation \
                     into two queries: one for the inserts, one for the \
                     updates.",
                    table_key,
                    prior.schema.fields().len(),
                    incoming.fields().len(),
                )));
            }
        }

        let slot = self
            .pending
            .entry(table_key.to_owned())
            .or_insert_with(|| PendingTable::new(schema, mode));
        // A merge batch gives an append-only entry upsert-by-id semantics.
        if slot.mode == PendingMode::Append && mode == PendingMode::Merge {
            slot.mode = PendingMode::Merge;
        }
        slot.batches.push(batch);
        Ok(())
    }

    /// Adds a delete predicate for `table_key`.
    pub(crate) fn record_delete(&mut self, table_key: &str, predicate: String) {
        let predicates = self
            .delete_predicates
            .entry(table_key.to_owned())
            .or_default();
        predicates.push(predicate);
    }

    /// Appends `ids` to the deleted-id list of `table_key`; an empty slice is a no-op.
    pub(crate) fn record_deleted_ids(&mut self, table_key: &str, ids: &[String]) {
        if !ids.is_empty() {
            self.deleted_ids
                .entry(table_key.to_owned())
                .or_default()
                .extend_from_slice(ids);
        }
    }

    /// Delete predicates recorded for `table_key` (empty when there are none).
    pub(crate) fn recorded_delete_predicates(&self, table_key: &str) -> &[String] {
        self.delete_predicates
            .get(table_key)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Batches queued for `table_key` (empty when there are none).
    pub(crate) fn pending_batches(&self, table_key: &str) -> &[RecordBatch] {
        self.pending
            .get(table_key)
            .map(|table| table.batches.as_slice())
            .unwrap_or_default()
    }

    /// Builds the validation change set.
    ///
    /// Every table with queued batches contributes them as `changed` rows. Every
    /// non-empty deleted-id list sets that table's `deleted_ids`, creating a
    /// default entry when the table has no queued batches.
    pub(crate) fn to_changeset(&self) -> crate::validate::ChangeSet {
        let mut changeset = crate::validate::ChangeSet::new();
        for (table_key, table) in &self.pending {
            if table.batches.is_empty() {
                continue;
            }
            let mut change = crate::validate::TableChange::default();
            change.changed.extend(table.batches.iter().cloned());
            changeset.insert(table_key.clone(), change);
        }
        for (table_key, ids) in &self.deleted_ids {
            if !ids.is_empty() {
                changeset.entry(table_key.clone()).or_default().deleted_ids = ids.clone();
            }
        }
        changeset
    }

    /// Schema of the queued entry for `table_key`, if any.
    pub(crate) fn pending_schema(&self, table_key: &str) -> Option<SchemaRef> {
        self.pending.get(table_key).map(|table| table.schema.clone())
    }

    /// `true` when nothing is queued and no delete predicate is recorded.
    /// `deleted_ids` is not consulted.
    pub(crate) fn is_empty(&self) -> bool {
        self.delete_predicates.is_empty() && self.pending.is_empty()
    }

    /// Total queued rows across all tables.
    // No caller is visible in this file; the allow keeps the unused-code lint quiet.
    #[allow(dead_code)]
    pub(crate) fn pending_row_count(&self) -> usize {
        self.pending.values().map(PendingTable::total_rows).sum()
    }

    /// Stages every table sequentially (a concurrency of 1).
    pub(crate) async fn stage_all(
        self,
        db: &crate::db::Omnigraph,
        branch: Option<&str>,
    ) -> Result<StagedMutation> {
        const SEQUENTIAL: usize = 1;
        self.stage_all_with_concurrency(db, branch, SEQUENTIAL)
            .await
    }

    /// Consumes the staging area and stages, without committing, every write.
    ///
    /// Tables with queued batches go through `stage_pending_table`, with at
    /// most `concurrency` (clamped to at least 1) in flight. Results keep input
    /// order and the first error in that order is returned. Delete predicates
    /// are then staged one table at a time: a lone predicate is used verbatim,
    /// and several are parenthesised and joined with ` OR `. `_branch` is unused
    /// and `deleted_ids` is discarded.
    ///
    /// NOTE(review): a table with both queued batches and delete predicates
    /// yields two staged entries for the same key; confirm that `commit_all`
    /// and `acquire_many` tolerate duplicate keys.
    ///
    /// # Errors
    /// Returns an internal error when a touched table has no registered path or
    /// expected version. Also propagates any reopen or staging failure.
    pub(crate) async fn stage_all_with_concurrency(
        self,
        db: &crate::db::Omnigraph,
        _branch: Option<&str>,
        concurrency: usize,
    ) -> Result<StagedMutation> {
        let Self {
            expected_versions,
            paths,
            pending,
            delete_predicates,
            deleted_ids: _,
            op_kinds,
        } = self;

        // Resolve every pending table's registration before any storage work.
        let mut jobs = Vec::with_capacity(pending.len());
        for (table_key, table) in pending {
            let path = match paths.get(&table_key) {
                Some(path) => path.clone(),
                None => {
                    return Err(OmniError::manifest_internal(format!(
                        "MutationStaging::stage_all: missing path for table '{}'",
                        table_key
                    )))
                }
            };
            let expected = match expected_versions.get(&table_key) {
                Some(version) => *version,
                None => {
                    return Err(OmniError::manifest_internal(format!(
                        "MutationStaging::stage_all: missing expected version for table '{}'",
                        table_key
                    )))
                }
            };
            jobs.push((table_key, table, path, expected));
        }

        let width = std::cmp::max(concurrency.min(jobs.len()), 1);
        let outcomes: Vec<Result<Option<StagedTableEntry>>> = futures::stream::iter(jobs)
            .map(move |(table_key, table, path, expected)| {
                stage_pending_table(db, table_key, table, path, expected)
            })
            .buffered(width)
            .collect()
            .await;

        let mut staged = Vec::with_capacity(outcomes.len());
        for outcome in outcomes {
            if let Some(entry) = outcome? {
                staged.push(entry);
            }
        }

        for (table_key, mut predicates) in delete_predicates {
            let path = match paths.get(&table_key) {
                Some(path) => path.clone(),
                None => {
                    return Err(OmniError::manifest_internal(format!(
                        "MutationStaging::stage_all: missing path for delete table '{}'",
                        table_key
                    )))
                }
            };
            let expected = match expected_versions.get(&table_key) {
                Some(version) => *version,
                None => {
                    return Err(OmniError::manifest_internal(format!(
                        "MutationStaging::stage_all: missing expected version for delete table '{}'",
                        table_key
                    )))
                }
            };
            let predicate = if predicates.len() == 1 {
                predicates.swap_remove(0)
            } else {
                let wrapped: Vec<String> =
                    predicates.iter().map(|p| format!("({})", p)).collect();
                wrapped.join(" OR ")
            };
            let delete_entry =
                stage_delete_table(db, table_key, predicate, path, expected).await?;
            if let Some(entry) = delete_entry {
                staged.push(entry);
            }
        }

        Ok(StagedMutation {
            staged,
            expected_versions,
            op_kinds,
        })
    }
}
async fn stage_pending_table(
db: &crate::db::Omnigraph,
table_key: String,
table: PendingTable,
path: StagedTablePath,
expected: u64,
) -> Result<Option<StagedTableEntry>> {
let stage_kind = match table.mode {
PendingMode::Append => crate::db::MutationOpKind::Insert,
PendingMode::Merge => crate::db::MutationOpKind::Merge,
PendingMode::Overwrite => crate::db::MutationOpKind::SchemaRewrite,
};
let ds = db
.reopen_for_mutation(
&table_key,
&path.full_path,
path.table_branch.as_deref(),
expected,
stage_kind,
)
.await?;
if table.batches.is_empty() {
return Ok(None);
}
let combined = match table.mode {
PendingMode::Merge => dedupe_merge_batches_by_id(&table.schema, table.batches)?,
PendingMode::Append | PendingMode::Overwrite => {
if table.batches.len() == 1 {
table.batches.into_iter().next().unwrap()
} else {
arrow_select::concat::concat_batches(&table.schema, &table.batches)
.map_err(|e| OmniError::Lance(e.to_string()))?
}
}
};
let staged = match table.mode {
PendingMode::Append => db.storage().stage_append(&ds, combined, &[]).await?,
PendingMode::Merge => {
db.storage()
.stage_merge_insert(
ds.clone(),
combined,
vec!["id".to_string()],
lance::dataset::WhenMatched::UpdateAll,
lance::dataset::WhenNotMatched::InsertAll,
)
.await?
}
PendingMode::Overwrite => db.storage().stage_overwrite(&ds, combined).await?,
};
Ok(Some(StagedTableEntry {
table_key,
path,
expected_version: expected,
dataset: ds,
staged_write: staged,
}))
}
/// Stages a delete of the rows matching `predicate` on one table.
///
/// The dataset is reopened through `reopen_for_mutation` with the `Delete` op
/// kind. `Ok(None)` is returned when `stage_delete` yields nothing to stage.
async fn stage_delete_table(
    db: &crate::db::Omnigraph,
    table_key: String,
    predicate: String,
    path: StagedTablePath,
    expected: u64,
) -> Result<Option<StagedTableEntry>> {
    let dataset = db
        .reopen_for_mutation(
            &table_key,
            &path.full_path,
            path.table_branch.as_deref(),
            expected,
            crate::db::MutationOpKind::Delete,
        )
        .await?;
    let maybe_staged = db.storage().stage_delete(&dataset, &predicate).await?;
    Ok(maybe_staged.map(|staged_write| StagedTableEntry {
        table_key,
        path,
        expected_version: expected,
        dataset,
        staged_write,
    }))
}
/// Writes staged against storage but not yet committed.
///
/// Produced by [`MutationStaging::stage_all`] and consumed by
/// [`StagedMutation::commit_all`].
pub(crate) struct StagedMutation {
    /// One entry per staged write: tables with queued batches first, then deletes.
    staged: Vec<StagedTableEntry>,
    /// Expected version per touched table, carried over from staging.
    expected_versions: HashMap<String, u64>,
    /// Op kind per touched table; decides whether commit-time version checks are strict.
    op_kinds: HashMap<String, MutationOpKind>,
}
/// One staged, uncommitted write for a single table.
struct StagedTableEntry {
    /// Key used for manifest-snapshot lookups and the write-queue key.
    table_key: String,
    /// Dataset location and table branch.
    path: StagedTablePath,
    /// Version the write was staged against. At commit time it is replaced by
    /// the version found in the manifest snapshot.
    expected_version: u64,
    /// Dataset handle returned by `reopen_for_mutation`.
    dataset: SnapshotHandle,
    /// Staged write handed to `commit_staged`.
    staged_write: StagedHandle,
}
/// Result of [`StagedMutation::commit_all`].
pub(crate) struct CommittedMutation {
    /// One update per committed write (new version, row count, version metadata), in commit order.
    pub(crate) updates: Vec<SubTableUpdate>,
    /// Expected versions; entries for committed tables hold the manifest version seen at commit time.
    pub(crate) expected_versions: HashMap<String, u64>,
    /// Recovery sidecar written before committing; `None` when nothing was staged.
    pub(crate) sidecar_handle: Option<RecoverySidecarHandle>,
    /// Write-queue guards held for the touched tables
    /// (presumably released when this value is dropped; confirm).
    pub(crate) guards: Vec<tokio::sync::OwnedMutexGuard<()>>,
    /// Handle returned by `commit_staged`, per table key (a later entry for the same key replaces an earlier one).
    pub(crate) committed_handles: HashMap<String, SnapshotHandle>,
}
impl StagedMutation {
pub(crate) async fn commit_all(
self,
db: &crate::db::Omnigraph,
branch: Option<&str>,
sidecar_kind: SidecarKind,
actor_id: Option<&str>,
held_guards: Option<(
Vec<(String, Option<String>)>,
Vec<tokio::sync::OwnedMutexGuard<()>>,
)>,
txn: Option<&crate::db::WriteTxn>,
) -> Result<CommittedMutation> {
let StagedMutation {
mut staged,
mut expected_versions,
op_kinds,
} = self;
let mut queue_keys: Vec<(String, Option<String>)> =
Vec::with_capacity(staged.len());
for entry in &staged {
queue_keys.push((entry.table_key.clone(), entry.path.table_branch.clone()));
}
let guards = match held_guards {
Some((acquired_keys, guards)) => {
let held: std::collections::HashSet<&(String, Option<String>)> =
acquired_keys.iter().collect();
if let Some(missing) = queue_keys.iter().find(|k| !held.contains(k)) {
return Err(OmniError::manifest_internal(format!(
"commit_all: pre-held write-queue guards do not cover touched table \
'{}' on branch {:?} — the caller's up-front acquisition set diverged \
from the staged/inline set (a touched-table-set bug)",
missing.0, missing.1
)));
}
guards
}
None => db.write_queue().acquire_many(&queue_keys).await,
};
let snapshot = match txn {
Some(_) => db.occ_snapshot_for_branch(branch).await?,
None => db.fresh_snapshot_for_branch(branch).await?,
};
for entry in staged.iter_mut() {
let current = snapshot
.entry(&entry.table_key)
.map(|e| e.table_version)
.ok_or_else(|| {
OmniError::manifest_conflict(format!(
"table '{}' missing from manifest at commit time",
entry.table_key,
))
})?;
let strict = op_kinds
.get(&entry.table_key)
.map(|k| k.strict_pre_stage_version_check())
.unwrap_or(false);
if strict && entry.expected_version != current {
return Err(OmniError::manifest_expected_version_mismatch(
entry.table_key.clone(),
entry.expected_version,
current,
));
}
let head = entry
.dataset
.dataset()
.latest_version_id()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
if head < current {
return Err(OmniError::manifest_internal(format!(
"table '{}' Lance HEAD version {} is behind manifest version {}",
entry.table_key, head, current
)));
}
if head > current {
let action = match crate::db::manifest::list_sidecars(
db.root_uri(),
db.storage_adapter(),
)
.await
{
Ok(sidecars) => {
let covered = sidecars.iter().any(|sidecar| {
sidecar.tables.iter().any(|pin| {
pin.table_key == entry.table_key
&& pin.table_branch == entry.path.table_branch
})
});
if covered {
"a pending recovery sidecar requires rollback — reopen the \
graph read-write (e.g. restart the server) to recover"
.to_string()
} else {
"run `omnigraph repair` before writing".to_string()
}
}
Err(list_err) => format!(
"could not classify the drift (sidecar listing failed: {}); \
run `omnigraph repair`, or reopen the graph read-write if \
repair reports a pending recovery sidecar",
list_err
),
};
return Err(OmniError::manifest_conflict(format!(
"table '{}' has Lance HEAD version {} ahead of manifest version {}; {}",
entry.table_key, head, current, action
)));
}
entry.expected_version = current;
expected_versions.insert(entry.table_key.clone(), current);
}
let mut pins: Vec<SidecarTablePin> = Vec::with_capacity(staged.len());
for entry in &staged {
pins.push(SidecarTablePin {
table_key: entry.table_key.clone(),
table_path: entry.path.full_path.clone(),
expected_version: entry.expected_version,
post_commit_pin: entry.expected_version + 1,
confirmed_version: None,
table_branch: entry.path.table_branch.clone(),
});
}
let sidecar_handle = if pins.is_empty() {
None
} else {
let sidecar = new_sidecar(
sidecar_kind,
branch.map(|s| s.to_string()),
actor_id.map(str::to_string),
pins,
);
Some(write_sidecar(db.root_uri(), db.storage_adapter(), &sidecar).await?)
};
let mut updates: Vec<SubTableUpdate> = Vec::with_capacity(staged.len());
let mut committed_handles: HashMap<String, SnapshotHandle> =
HashMap::with_capacity(staged.len());
for entry in staged {
let StagedTableEntry {
table_key,
path,
expected_version: _,
dataset,
staged_write,
} = entry;
let new_ds = db.storage().commit_staged(dataset, staged_write).await?;
let state = db.storage().table_state(&path.full_path, &new_ds).await?;
updates.push(SubTableUpdate {
table_key: table_key.clone(),
table_version: state.version,
table_branch: path.table_branch.clone(),
row_count: state.row_count,
version_metadata: state.version_metadata,
});
committed_handles.insert(table_key, new_ds);
}
Ok(CommittedMutation {
updates,
expected_versions,
sidecar_handle,
guards,
committed_handles,
})
}
}
/// Returns `true` when `a` and `b` have the same number of fields and every
/// positional pair of fields has the same name and data type.
///
/// Nullability and field/schema metadata are not compared.
fn schemas_compatible(a: &SchemaRef, b: &SchemaRef) -> bool {
    let (left, right) = (a.fields(), b.fields());
    left.len() == right.len()
        && left
            .iter()
            .zip(right.iter())
            .all(|(l, r)| l.name() == r.name() && l.data_type() == r.data_type())
}
fn dedupe_merge_batches_by_id(
schema: &SchemaRef,
batches: Vec<RecordBatch>,
) -> Result<RecordBatch> {
if batches.is_empty() {
return Err(OmniError::manifest_internal(
"dedupe_merge_batches_by_id: batches is empty".to_string(),
));
}
let mut seen: HashSet<String> = HashSet::new();
let mut keep: Vec<Vec<u32>> = vec![Vec::new(); batches.len()];
let mut any_duplicates = false;
for (b_idx, batch) in batches.iter().enumerate().rev() {
let id_col = batch
.column_by_name("id")
.ok_or_else(|| {
OmniError::manifest_internal(
"dedupe_merge_batches_by_id: batch has no 'id' column".to_string(),
)
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
OmniError::manifest_internal(
"dedupe_merge_batches_by_id: 'id' column is not Utf8".to_string(),
)
})?;
for r_idx in (0..batch.num_rows()).rev() {
if !id_col.is_valid(r_idx) {
keep[b_idx].push(r_idx as u32);
continue;
}
let id = id_col.value(r_idx);
if seen.insert(id.to_string()) {
keep[b_idx].push(r_idx as u32);
} else {
any_duplicates = true;
}
}
keep[b_idx].reverse();
}
if !any_duplicates {
if batches.len() == 1 {
return Ok(batches.into_iter().next().unwrap());
}
return arrow_select::concat::concat_batches(schema, &batches)
.map_err(|e| OmniError::Lance(e.to_string()));
}
let mut sliced: Vec<RecordBatch> = Vec::with_capacity(batches.len());
for (b_idx, idxs) in keep.into_iter().enumerate() {
if idxs.is_empty() {
continue;
}
let take_array = UInt32Array::from(idxs);
let columns: Vec<Arc<dyn Array>> = batches[b_idx]
.columns()
.iter()
.map(|col| arrow_select::take::take(col, &take_array, None))
.collect::<std::result::Result<_, _>>()
.map_err(|e| OmniError::Lance(e.to_string()))?;
let new_batch = RecordBatch::try_new(batches[b_idx].schema(), columns)
.map_err(|e| OmniError::Lance(e.to_string()))?;
sliced.push(new_batch);
}
if sliced.is_empty() {
return Err(OmniError::manifest_internal(
"dedupe_merge_batches_by_id: all rows were dropped (unexpected)".to_string(),
));
}
if sliced.len() == 1 {
return Ok(sliced.into_iter().next().unwrap());
}
arrow_select::concat::concat_batches(schema, &sliced)
.map_err(|e| OmniError::Lance(e.to_string()))
}