use super::*;
/// Number of single-row batches a `StagedTableWriter` buffers in memory before
/// concatenating them and flushing them to its staged Lance dataset.
const MERGE_STAGE_BATCH_ROWS: usize = 8192;
/// Environment variable naming the directory in which merge staging temp
/// directories are created; when it is not set, `merge_stage_tempdir` falls
/// back to the default temporary directory.
const MERGE_STAGE_DIR_ENV: &str = "OMNIGRAPH_MERGE_STAGING_DIR";
/// Per-table plan produced while classifying a branch merge, consumed first by
/// validation (`build_merge_changeset`) and then by the publish step.
#[derive(Debug)]
enum CandidateTableState {
    /// Publish the source branch's manifest state for this table
    /// (`publish_adopted_source_state`). `validation_delta`, when present, is
    /// the base→source row diff and is used only to validate the merge.
    AdoptSourceState {
        validation_delta: Option<AdoptDelta>,
    },
    /// Replay the base→source row diff onto the target table
    /// (`publish_adopted_delta`): append, upsert, then delete.
    AdoptWithDelta(AdoptDelta),
    /// Both sides changed the table; apply the three-way-merged delta
    /// (`publish_rewritten_merge_table`).
    RewriteMerged(StagedMergeResult),
}
/// A Lance dataset written into a private temporary directory.
#[derive(Debug)]
struct StagedTable {
    // Held only so the backing temp directory lives as long as `dataset`.
    _dir: TempDir,
    dataset: Dataset,
}
/// Output of a three-way table merge.
#[derive(Debug)]
struct StagedMergeResult {
    /// Rows whose merged state differs from the target; merge-inserted by `id`.
    delta_staged: Option<StagedTable>,
    /// Ids present on the target that the merged state removes.
    deleted_ids: Vec<String>,
}
/// Row-level diff between the merge base and the source branch for one table.
#[derive(Debug)]
struct AdoptDelta {
    /// Rows present only on the source (appended to the target).
    appends: Option<StagedTable>,
    /// Rows present on both whose contents changed (merge-inserted by `id`).
    upserts: Option<StagedTable>,
    /// Ids present in base but missing from the source.
    deleted_ids: Vec<String>,
}
/// One row yielded by an `OrderedTableCursor`, referencing its source batch.
#[derive(Debug, Clone)]
struct CursorRow {
    /// Value of the row's `id` column.
    id: String,
    /// Row signature when the cursor computes them eagerly; empty otherwise.
    signature: String,
    /// Dataset the row was read from (used to materialize blob columns).
    dataset: Dataset,
    /// Batch holding the row; `row_index` locates it inside the batch.
    batch: RecordBatch,
    row_index: usize,
}
impl CursorRow {
    /// Computes the row signature on demand (for cursors opened lazily).
    fn compute_signature(&self) -> Result<String> {
        row_signature(&self.batch, self.row_index)
    }
}
/// Streams one table version's rows in ascending `id` order (nulls last) with a
/// single-row lookahead, so several snapshots can be merge-joined on `id`.
///
/// A cursor over a table absent from its snapshot is simply empty.
struct OrderedTableCursor {
    stream: Option<std::pin::Pin<Box<DatasetRecordBatchStream>>>,
    dataset: Option<Dataset>,
    current_batch: Option<RecordBatch>,
    current_row: usize,
    peeked: Option<CursorRow>,
    eager_signatures: bool,
}
impl OrderedTableCursor {
    /// Opens a cursor that computes each row's signature as it is read.
    async fn from_snapshot(snapshot: &Snapshot, table_key: &str) -> Result<Self> {
        Self::open(snapshot, table_key, true).await
    }

    /// Opens a cursor that leaves `CursorRow::signature` empty; callers use
    /// `CursorRow::compute_signature` when they need one.
    async fn from_snapshot_lazy(snapshot: &Snapshot, table_key: &str) -> Result<Self> {
        Self::open(snapshot, table_key, false).await
    }

    async fn open(snapshot: &Snapshot, table_key: &str, eager_signatures: bool) -> Result<Self> {
        let dataset = if snapshot.entry(table_key).is_some() {
            Some(snapshot.open(table_key).await?)
        } else {
            None
        };
        Self::from_dataset(dataset, eager_signatures).await
    }

    async fn from_dataset(dataset: Option<Dataset>, eager_signatures: bool) -> Result<Self> {
        let mut stream = None;
        if let Some(ds) = dataset.as_ref() {
            let ordered = crate::table_store::TableStore::scan_stream_with(
                ds,
                None,
                None,
                Some(vec![ColumnOrdering::asc_nulls_last("id".to_string())]),
                true,
                |_| Ok(()),
            )
            .await?;
            stream = Some(Box::pin(ordered));
        }
        Ok(Self {
            stream,
            dataset,
            current_batch: None,
            current_row: 0,
            peeked: None,
            eager_signatures,
        })
    }

    /// Returns a copy of the next row without consuming it.
    async fn peek_cloned(&mut self) -> Result<Option<CursorRow>> {
        match &self.peeked {
            Some(row) => Ok(Some(row.clone())),
            None => {
                let upcoming = self.next_row().await?;
                self.peeked = upcoming.clone();
                Ok(upcoming)
            }
        }
    }

    /// Consumes and returns the next row (the peeked one, if any).
    async fn pop(&mut self) -> Result<Option<CursorRow>> {
        match self.peeked.take() {
            Some(row) => Ok(Some(row)),
            None => self.next_row().await,
        }
    }

    /// Takes the next row from the buffered batch, if it has one left.
    fn next_buffered_row(&mut self) -> Result<Option<CursorRow>> {
        let Some(batch) = self.current_batch.as_ref() else {
            return Ok(None);
        };
        if self.current_row >= batch.num_rows() {
            return Ok(None);
        }
        let row_index = self.current_row;
        self.current_row += 1;
        let dataset = self.dataset.clone().ok_or_else(|| {
            OmniError::manifest("cursor row missing source dataset".to_string())
        })?;
        let signature = if self.eager_signatures {
            row_signature(batch, row_index)?
        } else {
            String::new()
        };
        let id = row_id_at(batch, row_index)?;
        Ok(Some(CursorRow {
            id,
            signature,
            dataset,
            batch: batch.clone(),
            row_index,
        }))
    }

    /// Reads the next row, pulling new batches from the stream as needed.
    async fn next_row(&mut self) -> Result<Option<CursorRow>> {
        loop {
            if let Some(row) = self.next_buffered_row()? {
                return Ok(Some(row));
            }
            let Some(stream) = self.stream.as_mut() else {
                return Ok(None);
            };
            let fetched = stream
                .try_next()
                .await
                .map_err(|err| OmniError::Lance(err.to_string()))?;
            match fetched {
                Some(batch) => {
                    self.current_batch = Some(batch);
                    self.current_row = 0;
                }
                None => {
                    // Stream exhausted: drop it so later calls return None immediately.
                    self.stream = None;
                    self.current_batch = None;
                    return Ok(None);
                }
            }
        }
    }
}
/// Accumulates individual cursor rows and writes them into a Lance dataset
/// inside a private temp directory, flushing every `MERGE_STAGE_BATCH_ROWS` rows.
struct StagedTableWriter {
    schema: SchemaRef,
    dataset_uri: String,
    dir: TempDir,
    dataset: Option<Dataset>,
    buffered_rows: usize,
    row_count: u64,
    batches: Vec<RecordBatch>,
}
impl StagedTableWriter {
    /// Creates a writer whose dataset lives at `<tempdir>/table.lance`.
    fn new(table_key: &str, schema: SchemaRef) -> Result<Self> {
        let dir = merge_stage_tempdir(table_key)?;
        let dataset_uri = dir
            .path()
            .join("table.lance")
            .to_string_lossy()
            .into_owned();
        Ok(Self {
            schema,
            dataset_uri,
            dir,
            dataset: None,
            buffered_rows: 0,
            row_count: 0,
            batches: Vec::new(),
        })
    }

    /// Buffers one row; flushes once the buffer reaches the batch threshold.
    async fn push_row(&mut self, row: &CursorRow) -> Result<()> {
        self.row_count += 1;
        self.buffered_rows += 1;
        let single = self.row_batch(row).await?;
        self.batches.push(single);
        if self.buffered_rows < MERGE_STAGE_BATCH_ROWS {
            return Ok(());
        }
        self.flush().await
    }

    /// Builds a one-row batch for `row`, projected onto the writer schema, or
    /// with blob columns materialized when the source dataset has any.
    async fn row_batch(&self, row: &CursorRow) -> Result<RecordBatch> {
        let single = row.batch.slice(row.row_index, 1);
        let dataset_has_blobs = row
            .dataset
            .schema()
            .fields_pre_order()
            .any(|field| field.is_blob());
        if dataset_has_blobs {
            return crate::table_store::TableStore::materialize_blob_batch(&row.dataset, single)
                .await;
        }
        let mut columns = Vec::with_capacity(self.schema.fields().len());
        for field in self.schema.fields().iter() {
            let column = single.column_by_name(field.name()).cloned().ok_or_else(|| {
                OmniError::Lance(format!("batch missing column '{}'", field.name()))
            })?;
            columns.push(column);
        }
        RecordBatch::try_new(self.schema.clone(), columns)
            .map_err(|e| OmniError::Lance(e.to_string()))
    }

    /// Flushes pending rows and returns the staged table, creating an empty
    /// dataset when no rows were ever written.
    async fn finish(mut self) -> Result<StagedTable> {
        self.flush().await?;
        let dataset = match self.dataset.take() {
            Some(existing) => existing,
            None => {
                crate::table_store::TableStore::create_empty_dataset(
                    &self.dataset_uri,
                    &self.schema,
                )
                .await?
            }
        };
        Ok(StagedTable {
            _dir: self.dir,
            dataset,
        })
    }

    /// Concatenates buffered rows and appends them to the staged dataset.
    async fn flush(&mut self) -> Result<()> {
        let combined = match self.batches.len() {
            0 => return Ok(()),
            1 => self.batches.pop().unwrap(),
            _ => {
                let pending = std::mem::take(&mut self.batches);
                arrow_select::concat::concat_batches(&self.schema, &pending)
                    .map_err(|e| OmniError::Lance(e.to_string()))?
            }
        };
        self.buffered_rows = 0;
        let written = crate::table_store::TableStore::append_or_create_batch(
            &self.dataset_uri,
            self.dataset.take(),
            combined,
        )
        .await?;
        self.dataset = Some(written);
        Ok(())
    }
}
/// Creates a fresh temporary directory for staging one table's merge output.
///
/// The directory is created under `$OMNIGRAPH_MERGE_STAGING_DIR` when that
/// variable is set to a non-empty path, and under the default temp dir
/// otherwise. The name is prefixed with `omnigraph-merge-<sanitized key>-`.
///
/// # Errors
/// Returns an I/O-derived `OmniError` if the directory cannot be created
/// (e.g. the configured root does not exist).
fn merge_stage_tempdir(table_key: &str) -> Result<TempDir> {
    let prefix = format!("omnigraph-merge-{}-", sanitize_table_key(table_key));
    let mut builder = TempDirBuilder::new();
    builder.prefix(&prefix);
    // `var_os` honours non-UTF-8 paths; an empty value is treated as unset so
    // the staging dir is never silently created relative to the CWD.
    let created = match env::var_os(MERGE_STAGE_DIR_ENV).filter(|root| !root.is_empty()) {
        Some(root) => builder.tempdir_in(PathBuf::from(root)),
        None => builder.tempdir(),
    };
    created.map_err(OmniError::from)
}
/// Makes a table key safe for use inside a file name by replacing the path-
/// and key-separator characters `:`, `/` and `\` with `-`.
fn sanitize_table_key(table_key: &str) -> String {
    table_key.replace([':', '/', '\\'], "-")
}
async fn compute_adopt_delta(
table_key: &str,
catalog: &Catalog,
base_snapshot: &Snapshot,
source_snapshot: &Snapshot,
) -> Result<Option<AdoptDelta>> {
let schema = schema_for_table_key(catalog, table_key)?;
let mut append_writer =
StagedTableWriter::new(&format!("{}_adopt_append", table_key), schema.clone())?;
let mut upsert_writer =
StagedTableWriter::new(&format!("{}_adopt_upsert", table_key), schema)?;
let mut deleted_ids: Vec<String> = Vec::new();
let mut base = OrderedTableCursor::from_snapshot_lazy(base_snapshot, table_key).await?;
let mut source = OrderedTableCursor::from_snapshot_lazy(source_snapshot, table_key).await?;
let mut needs_update = false;
loop {
let base_row = base.peek_cloned().await?;
let source_row = source.peek_cloned().await?;
let next_id = [base_row.as_ref(), source_row.as_ref()]
.into_iter()
.flatten()
.map(|row| row.id.clone())
.min();
let Some(next_id) = next_id else { break };
let base_row = if base_row.as_ref().map(|r| r.id.as_str()) == Some(next_id.as_str()) {
base.pop().await?
} else {
None
};
let source_row = if source_row.as_ref().map(|r| r.id.as_str()) == Some(next_id.as_str()) {
source.pop().await?
} else {
None
};
match (&base_row, &source_row) {
(Some(_), None) => {
deleted_ids.push(next_id);
needs_update = true;
}
(None, Some(src)) => {
append_writer.push_row(src).await?;
needs_update = true;
}
(Some(base), Some(src)) => {
if src.compute_signature()? != base.compute_signature()? {
upsert_writer.push_row(src).await?;
needs_update = true;
}
}
(None, None) => unreachable!(),
}
}
if !needs_update {
return Ok(None);
}
let appends = if append_writer.row_count > 0 {
Some(append_writer.finish().await?)
} else {
None
};
let upserts = if upsert_writer.row_count > 0 {
Some(upsert_writer.finish().await?)
} else {
None
};
Ok(Some(AdoptDelta {
appends,
upserts,
deleted_ids,
}))
}
/// Returns the smallest `id` among the present cursor heads, or `None` when
/// all three cursors are exhausted.
fn min_cursor_id(
    base_row: &Option<CursorRow>,
    source_row: &Option<CursorRow>,
    target_row: &Option<CursorRow>,
) -> Option<String> {
    [base_row.as_ref(), source_row.as_ref(), target_row.as_ref()]
        .iter()
        .flatten()
        .map(|row| row.id.as_str())
        .min()
        .map(str::to_owned)
}
/// Three-way merges one table's rows from the base, source and target snapshots.
///
/// All three are streamed in ascending `id` order; for each id the surviving
/// state is chosen as follows:
/// - source unchanged vs base → keep the target's row (or its absence);
/// - target unchanged vs base → take the source's row (or its absence);
/// - both changed identically → keep the target's row;
/// - otherwise → record a conflict in `conflicts`.
///
/// Chosen rows that differ from the target are written to a staged delta;
/// ids present on the target but absent from the chosen state are collected
/// as deletions.
///
/// # Returns
/// `Ok(None)` when this table produced conflicts (appended to `conflicts`) or
/// when the target already equals the merged state; otherwise the staged
/// delta plus the ids to delete.
///
/// # Errors
/// Propagates schema lookup, scan, signature and staging-write errors.
async fn stage_streaming_table_merge(
    table_key: &str,
    catalog: &Catalog,
    base_snapshot: &Snapshot,
    source_snapshot: &Snapshot,
    target_snapshot: &Snapshot,
    conflicts: &mut Vec<MergeConflict>,
) -> Result<Option<StagedMergeResult>> {
    let schema = schema_for_table_key(catalog, table_key)?;
    let mut delta_writer = StagedTableWriter::new(&format!("{}_delta", table_key), schema)?;
    let mut deleted_ids: Vec<String> = Vec::new();
    let mut base = OrderedTableCursor::from_snapshot(base_snapshot, table_key).await?;
    let mut source = OrderedTableCursor::from_snapshot(source_snapshot, table_key).await?;
    let mut target = OrderedTableCursor::from_snapshot(target_snapshot, table_key).await?;
    let prior_conflict_count = conflicts.len();
    let mut needs_update = false;
    loop {
        let base_row = base.peek_cloned().await?;
        let source_row = source.peek_cloned().await?;
        let target_row = target.peek_cloned().await?;
        let Some(next_id) = min_cursor_id(&base_row, &source_row, &target_row) else {
            break;
        };
        let base_row = if base_row.as_ref().map(|row| row.id.as_str()) == Some(next_id.as_str()) {
            base.pop().await?
        } else {
            None
        };
        let source_row = if source_row.as_ref().map(|row| row.id.as_str()) == Some(next_id.as_str())
        {
            source.pop().await?
        } else {
            None
        };
        let target_row = if target_row.as_ref().map(|row| row.id.as_str()) == Some(next_id.as_str())
        {
            target.pop().await?
        } else {
            None
        };
        let base_sig = base_row.as_ref().map(|row| row.signature.as_str());
        let source_sig = source_row.as_ref().map(|row| row.signature.as_str());
        let target_sig = target_row.as_ref().map(|row| row.signature.as_str());
        let source_changed = source_sig != base_sig;
        let target_changed = target_sig != base_sig;
        let selection = if !source_changed {
            target_row.as_ref()
        } else if !target_changed {
            source_row.as_ref()
        } else if source_sig == target_sig {
            target_row.as_ref()
        } else {
            conflicts.push(classify_merge_conflict(
                table_key, &next_id, base_sig, source_sig, target_sig,
            ));
            None
        };
        // Once this table has any conflict the result is discarded; keep
        // scanning only to report every conflicting id, not to stage rows.
        if conflicts.len() > prior_conflict_count {
            continue;
        }
        if selection.is_none() && target_row.is_some() {
            deleted_ids.push(next_id.clone());
            needs_update = true;
            continue;
        }
        if let Some(selection) = selection {
            // Compare as `Option`s: a missing target row must never be
            // confused with a present row whose signature happens to be empty.
            if Some(selection.signature.as_str()) != target_sig {
                delta_writer.push_row(selection).await?;
                needs_update = true;
            }
        }
    }
    if conflicts.len() > prior_conflict_count || !needs_update {
        return Ok(None);
    }
    let delta_staged = if delta_writer.row_count > 0 {
        Some(delta_writer.finish().await?)
    } else {
        None
    };
    Ok(Some(StagedMergeResult {
        delta_staged,
        deleted_ids,
    }))
}
/// Resolves the Arrow schema for a `node:<Type>` or `edge:<Type>` table key.
///
/// # Errors
/// Returns a manifest error for an unknown type name or a key with any
/// other prefix.
fn schema_for_table_key(catalog: &Catalog, table_key: &str) -> Result<SchemaRef> {
    match table_key.split_once(':') {
        Some(("node", name)) => catalog
            .node_types
            .get(name)
            .map(|node_type| node_type.arrow_schema.clone())
            .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", name))),
        Some(("edge", name)) => catalog
            .edge_types
            .get(name)
            .map(|edge_type| edge_type.arrow_schema.clone())
            .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", name))),
        _ => Err(OmniError::manifest(format!(
            "invalid table key '{}'",
            table_key
        ))),
    }
}
/// True when two manifest entries point at the same table version on the same
/// table branch, or when both are absent.
fn same_manifest_state(
    left: Option<&crate::db::SubTableEntry>,
    right: Option<&crate::db::SubTableEntry>,
) -> bool {
    match (left, right) {
        (None, None) => true,
        (Some(lhs), Some(rhs)) => {
            (&lhs.table_version, &lhs.table_branch) == (&rhs.table_version, &rhs.table_branch)
        }
        (Some(_), None) | (None, Some(_)) => false,
    }
}
/// Builds a `MergeConflict` for `row_id`, classified by which of the base,
/// source and target sides still have the row.
///
/// - absent in base, present on both sides → divergent insert;
/// - present in base, deleted on exactly one side → delete/update;
/// - anything else → divergent update.
fn classify_merge_conflict(
    table_key: &str,
    row_id: &str,
    base_sig: Option<&str>,
    source_sig: Option<&str>,
    target_sig: Option<&str>,
) -> MergeConflict {
    let presence = (base_sig.is_some(), source_sig.is_some(), target_sig.is_some());
    let (kind, message) = match presence {
        (false, true, true) => (
            MergeConflictKind::DivergentInsert,
            format!("divergent insert for id '{}'", row_id),
        ),
        (true, false, true) | (true, true, false) => (
            MergeConflictKind::DeleteVsUpdate,
            format!("delete/update conflict for id '{}'", row_id),
        ),
        _ => (
            MergeConflictKind::DivergentUpdate,
            format!("divergent update for id '{}'", row_id),
        ),
    };
    MergeConflict {
        table_key: table_key.to_owned(),
        row_id: Some(row_id.to_owned()),
        kind,
        message,
    }
}
/// Renders one row into a string that is equal for two rows exactly when
/// their (non-`_row*`) column values are equal, column by column in order.
///
/// Each column contributes either `N` (null) or `V<len>:<rendered value>`,
/// followed by a `\u{1f}` separator. The explicit null marker keeps null and
/// empty-string values apart (`array_value_to_string` renders null as `""`).
/// The length prefix stops a value that contains the separator from making
/// two different rows collide. Signatures are only compared with each other
/// during a single merge, so the encoding is internal.
///
/// # Errors
/// Returns `OmniError::Lance` if a value cannot be formatted.
fn row_signature(batch: &RecordBatch, row: usize) -> Result<String> {
    let schema = batch.schema();
    let mut signature = String::new();
    for (field, column) in schema.fields().iter().zip(batch.columns()) {
        // Lance system columns (`_rowid`, `_rowaddr`, ...) differ between
        // branches for identical data, so they are excluded.
        if field.name().starts_with("_row") {
            continue;
        }
        if column.is_null(row) {
            signature.push('N');
        } else {
            let rendered = array_value_to_string(column.as_ref(), row)
                .map_err(|e| OmniError::Lance(e.to_string()))?;
            signature.push('V');
            signature.push_str(&rendered.len().to_string());
            signature.push(':');
            signature.push_str(&rendered);
        }
        signature.push('\u{1f}');
    }
    Ok(signature)
}
/// Builds the validation changeset for every candidate table.
///
/// Staged appends become `added` rows, and staged upserts or merged deltas
/// become `changed` rows. Both are scanned with the lightweight validation
/// projection. Deleted ids are copied through. Tables adopted verbatim with no
/// validation delta are skipped.
///
/// # Errors
/// Propagates scan errors from the staged tables.
async fn build_merge_changeset(
    db: &Omnigraph,
    candidates: &HashMap<String, CandidateTableState>,
) -> Result<crate::validate::ChangeSet> {
    let catalog = db.catalog();
    let mut changeset = crate::validate::ChangeSet::new();
    for (table_key, candidate) in candidates.iter() {
        let (inserted, updated, removed) = match candidate {
            CandidateTableState::AdoptSourceState {
                validation_delta: None,
            } => continue,
            CandidateTableState::AdoptSourceState {
                validation_delta: Some(delta),
            }
            | CandidateTableState::AdoptWithDelta(delta) => (
                delta.appends.as_ref(),
                delta.upserts.as_ref(),
                &delta.deleted_ids,
            ),
            CandidateTableState::RewriteMerged(staged) => {
                (None, staged.delta_staged.as_ref(), &staged.deleted_ids)
            }
        };
        let columns = validation_projection(&catalog, table_key);
        let projection: Vec<&str> = columns.iter().map(String::as_str).collect();
        let mut change = crate::validate::TableChange::default();
        if let Some(table) = inserted {
            let batches = scan_staged_for_validation(db, table, &projection).await?;
            change.added.extend(batches);
        }
        if let Some(table) = updated {
            let batches = scan_staged_for_validation(db, table, &projection).await?;
            change.changed.extend(batches);
        }
        change.deleted_ids = removed.clone();
        changeset.insert(table_key.clone(), change);
    }
    Ok(changeset)
}
/// Lists the columns scanned for constraint validation of a table.
///
/// The list always includes `id`, adds `src`/`dst` for edge tables, and then
/// adds every declared property except vector and blob properties, which are
/// skipped to keep validation scans cheap. Unknown types contribute no
/// properties.
fn validation_projection(catalog: &Catalog, table_key: &str) -> Vec<String> {
    use omnigraph_compiler::types::{PropType, ScalarType};
    let is_light = |ty: &PropType| !matches!(ty.scalar, ScalarType::Vector(_) | ScalarType::Blob);
    let mut projection: Vec<String> = Vec::new();
    projection.push(String::from("id"));
    if let Some(node_name) = table_key.strip_prefix("node:") {
        let Some(node_type) = catalog.node_types.get(node_name) else {
            return projection;
        };
        for (prop, ty) in &node_type.properties {
            if is_light(ty) {
                projection.push(prop.clone());
            }
        }
        return projection;
    }
    if let Some(edge_name) = table_key.strip_prefix("edge:") {
        projection.push(String::from("src"));
        projection.push(String::from("dst"));
        if let Some(edge_type) = catalog.edge_types.get(edge_name) {
            for (prop, ty) in &edge_type.properties {
                if is_light(ty) {
                    projection.push(prop.clone());
                }
            }
        }
    }
    projection
}
/// Scans a staged table with the given column projection, dropping any
/// empty batches.
async fn scan_staged_for_validation(
    db: &Omnigraph,
    table: &StagedTable,
    projection: &[&str],
) -> Result<Vec<RecordBatch>> {
    let handle = SnapshotHandle::new(table.dataset.clone());
    let scanned = db
        .storage()
        .scan(&handle, Some(projection), None, None)
        .await?;
    let non_empty: Vec<RecordBatch> = scanned
        .into_iter()
        .filter(|batch| batch.num_rows() != 0)
        .collect();
    Ok(non_empty)
}
/// Evaluates catalog constraints against the merge changeset, with the target
/// snapshot as committed state.
///
/// # Errors
/// Returns `OmniError::MergeConflicts` with one entry per violation, or any
/// error raised while evaluating constraints.
async fn validate_merge_candidates(
    db: &Omnigraph,
    target_snapshot: &Snapshot,
    changeset: &crate::validate::ChangeSet,
) -> Result<()> {
    let committed = crate::validate::CommittedState::merge(target_snapshot);
    let catalog = db.catalog();
    let constraints = crate::validate::constraints_for(&catalog);
    let violations =
        crate::validate::evaluate(&constraints, changeset, &committed, &catalog).await?;
    if violations.is_empty() {
        return Ok(());
    }
    let conflicts = violations
        .into_iter()
        .map(crate::validate::Violation::into_merge_conflict)
        .collect();
    Err(OmniError::MergeConflicts(conflicts))
}
/// Reads the `id` value of `row` in `batch`.
///
/// # Errors
/// Returns a manifest error when the batch has no `id` column, when that
/// column is not `Utf8`, or when the id at `row` is null. A null id would
/// otherwise read back as `""` and silently merge unrelated rows under one key.
fn row_id_at(batch: &RecordBatch, row: usize) -> Result<String> {
    let ids = batch
        .column_by_name("id")
        .ok_or_else(|| OmniError::manifest("batch missing id column".to_string()))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest("id column is not Utf8".to_string()))?;
    if ids.is_null(row) {
        return Err(OmniError::manifest(format!(
            "id column is null at row {}",
            row
        )));
    }
    Ok(ids.value(row).to_string())
}
/// Decides how to publish a table that only the source branch changed.
///
/// The source's row delta against base is always computed. It is applied as
/// an `AdoptWithDelta` only when the source lives on a table branch and
/// publishing would advance the target's own head: either the target is main,
/// or the target entry is already on the active target branch. Otherwise the
/// source manifest state is adopted, with the delta kept for validation only.
/// A table missing from the source snapshot adopts that (absent) state.
async fn classify_adopt(
    target_db: &Omnigraph,
    catalog: &Catalog,
    base_snapshot: &Snapshot,
    source_snapshot: &Snapshot,
    target_snapshot: &Snapshot,
    table_key: &str,
) -> Result<CandidateTableState> {
    let source_entry = match source_snapshot.entry(table_key) {
        Some(entry) => entry,
        None => {
            return Ok(CandidateTableState::AdoptSourceState {
                validation_delta: None,
            })
        }
    };
    let target_entry = target_snapshot.entry(table_key);
    let target_active = target_db.active_branch().await;
    let advances_head = source_entry.table_branch.is_some()
        && match target_active.as_deref() {
            None => true,
            Some(target_branch) => {
                target_entry.and_then(|entry| entry.table_branch.as_deref())
                    == Some(target_branch)
            }
        };
    let validation_delta =
        compute_adopt_delta(table_key, catalog, base_snapshot, source_snapshot).await?;
    Ok(match validation_delta {
        Some(delta) if advances_head => CandidateTableState::AdoptWithDelta(delta),
        other => CandidateTableState::AdoptSourceState {
            validation_delta: other,
        },
    })
}
/// Produces the manifest update that adopts the source's state for a table.
///
/// - Source on main (no table branch): adopt the source version verbatim.
/// - Source on a branch, target on main: keep the target's version and
///   metadata when it has an entry, else fall back to the source's.
/// - Both on branches, target entry already on the target branch: keep the
///   target's version and metadata on that branch.
/// - Otherwise: fork the dataset from the source state onto the target
///   branch and publish the fork's resulting state.
///
/// # Errors
/// Returns a manifest error when the source has no entry for `table_key`, and
/// propagates fork and storage errors.
async fn publish_adopted_source_state(
    target_db: &Omnigraph,
    source_snapshot: &Snapshot,
    target_snapshot: &Snapshot,
    table_key: &str,
) -> Result<crate::db::SubTableUpdate> {
    let source_entry = source_snapshot
        .entry(table_key)
        .ok_or_else(|| OmniError::manifest(format!("missing source entry for {}", table_key)))?;
    let target_entry = target_snapshot.entry(table_key);
    let target_active = target_db.active_branch().await;

    let Some(source_branch) = source_entry.table_branch.as_deref() else {
        return Ok(crate::db::SubTableUpdate {
            table_key: table_key.to_string(),
            table_version: source_entry.table_version,
            table_branch: None,
            row_count: source_entry.row_count,
            version_metadata: source_entry.version_metadata.clone(),
        });
    };

    let Some(target_branch) = target_active.as_deref() else {
        return Ok(crate::db::SubTableUpdate {
            table_key: table_key.to_string(),
            table_version: target_entry
                .map(|e| e.table_version)
                .unwrap_or(source_entry.table_version),
            table_branch: None,
            row_count: source_entry.row_count,
            version_metadata: target_entry
                .map(|entry| entry.version_metadata.clone())
                .unwrap_or_else(|| source_entry.version_metadata.clone()),
        });
    };

    if let Some(on_branch) =
        target_entry.filter(|entry| entry.table_branch.as_deref() == Some(target_branch))
    {
        return Ok(crate::db::SubTableUpdate {
            table_key: table_key.to_string(),
            table_version: on_branch.table_version,
            table_branch: Some(target_branch.to_string()),
            row_count: source_entry.row_count,
            version_metadata: on_branch.version_metadata.clone(),
        });
    }

    let full_path = format!("{}/{}", target_db.uri(), source_entry.table_path);
    let forked = target_db
        .fork_dataset_from_entry_state(
            table_key,
            &full_path,
            Some(source_branch),
            source_entry.table_version,
            target_branch,
        )
        .await?;
    let state = target_db.storage().table_state(&full_path, &forked).await?;
    Ok(crate::db::SubTableUpdate {
        table_key: table_key.to_string(),
        table_version: state.version,
        table_branch: Some(target_branch.to_string()),
        row_count: state.row_count,
        version_metadata: state.version_metadata,
    })
}
/// Publishes a three-way-merged table onto the current target.
///
/// Steps, each committing a new dataset version:
/// 1. merge-insert the staged delta rows keyed on `id` (update or insert);
/// 2. delete `staged.deleted_ids`;
/// 3. rebuild indices when the table is non-empty.
///
/// Failpoints between the steps let tests simulate a crash mid-publish.
///
/// # Errors
/// Propagates open, scan, concat, commit, index-build and failpoint errors.
async fn publish_rewritten_merge_table(
    target_db: &Omnigraph,
    table_key: &str,
    staged: &StagedMergeResult,
) -> Result<crate::db::SubTableUpdate> {
    let (mut current_ds, full_path, table_branch) = target_db
        .open_for_mutation(table_key, crate::db::MutationOpKind::Merge)
        .await?
        .require_handle("branch merge");
    if let Some(delta) = &staged.delta_staged {
        // Same combine logic as `scan_staged_combined`, deliberately without
        // its instrumentation hook.
        let delta_snapshot = SnapshotHandle::new(delta.dataset.clone());
        let batches: Vec<RecordBatch> = target_db
            .storage()
            .scan_batches_for_rewrite(&delta_snapshot)
            .await?
            .into_iter()
            .filter(|batch| batch.num_rows() > 0)
            .collect();
        if !batches.is_empty() {
            let combined = if batches.len() == 1 {
                batches.into_iter().next().unwrap()
            } else {
                let schema = batches[0].schema();
                arrow_select::concat::concat_batches(&schema, &batches)
                    .map_err(|e| OmniError::Lance(e.to_string()))?
            };
            let staged_merge = target_db
                .storage()
                .stage_merge_insert(
                    current_ds.clone(),
                    combined,
                    vec!["id".to_string()],
                    lance::dataset::WhenMatched::UpdateAll,
                    lance::dataset::WhenNotMatched::InsertAll,
                )
                .await?;
            current_ds = target_db
                .storage()
                .commit_staged(current_ds, staged_merge)
                .await?;
        }
    }
    crate::failpoints::maybe_fail(crate::failpoints::names::BRANCH_MERGE_REWRITE_AFTER_MERGE_PRE_DELETE)?;
    if !staged.deleted_ids.is_empty() {
        // Single quotes are doubled so ids are safe inside the SQL string literal.
        let escaped: Vec<String> = staged
            .deleted_ids
            .iter()
            .map(|id| format!("'{}'", id.replace('\'', "''")))
            .collect();
        let filter = format!("id IN ({})", escaped.join(", "));
        if let Some(staged_delete) = target_db.storage().stage_delete(&current_ds, &filter).await? {
            current_ds = target_db
                .storage()
                .commit_staged(current_ds, staged_delete)
                .await?;
        }
    }
    crate::failpoints::maybe_fail(crate::failpoints::names::BRANCH_MERGE_REWRITE_AFTER_DELETE_PRE_INDEX)?;
    let row_count = target_db
        .storage()
        .table_state(&full_path, &current_ds)
        .await?
        .row_count;
    if row_count > 0 {
        target_db
            .build_indices_on_dataset(table_key, &mut current_ds)
            .await?;
    }
    let final_state = target_db
        .storage()
        .table_state(&full_path, &current_ds)
        .await?;
    Ok(crate::db::SubTableUpdate {
        table_key: table_key.to_string(),
        table_version: final_state.version,
        table_branch,
        row_count: final_state.row_count,
        version_metadata: final_state.version_metadata,
    })
}
/// Reads a staged table fully and concatenates its non-empty batches into one.
///
/// Returns `Ok(None)` when the staged table has no rows. Records the
/// `scan_staged_combined` instrumentation event on every call.
async fn scan_staged_combined(
    target_db: &Omnigraph,
    table: &StagedTable,
) -> Result<Option<RecordBatch>> {
    crate::instrumentation::record_scan_staged_combined();
    let handle = SnapshotHandle::new(table.dataset.clone());
    let batches: Vec<RecordBatch> = target_db
        .storage()
        .scan_batches_for_rewrite(&handle)
        .await?
        .into_iter()
        .filter(|batch| batch.num_rows() > 0)
        .collect();
    match batches.len() {
        0 => Ok(None),
        1 => Ok(batches.into_iter().next()),
        _ => {
            let schema = batches[0].schema();
            arrow_select::concat::concat_batches(&schema, &batches)
                .map(Some)
                .map_err(|e| OmniError::Lance(e.to_string()))
        }
    }
}
/// Replays a base→source row delta onto the target table.
///
/// Steps, each committing a new dataset version:
/// 1. append staged new rows;
/// 2. merge-insert staged changed rows keyed on `id`;
/// 3. delete `delta.deleted_ids`.
///
/// Failpoints between the steps let tests simulate a crash mid-publish.
///
/// # Errors
/// Propagates open, scan, commit and failpoint errors.
async fn publish_adopted_delta(
    target_db: &Omnigraph,
    table_key: &str,
    delta: &AdoptDelta,
) -> Result<crate::db::SubTableUpdate> {
    let (mut current_ds, full_path, table_branch) = target_db
        .open_for_mutation(table_key, crate::db::MutationOpKind::Merge)
        .await?
        .require_handle("branch merge");
    if let Some(append_table) = &delta.appends {
        let source = SnapshotHandle::new(append_table.dataset.clone());
        let staged = target_db
            .storage()
            .stage_append_stream(&current_ds, &source, &[])
            .await?;
        current_ds = target_db
            .storage()
            .commit_staged(current_ds, staged)
            .await?;
    }
    crate::failpoints::maybe_fail(crate::failpoints::names::BRANCH_MERGE_ADOPT_AFTER_APPEND_PRE_UPSERT)?;
    if let Some(upsert_table) = &delta.upserts {
        if let Some(combined) = scan_staged_combined(target_db, upsert_table).await? {
            let staged_merge = target_db
                .storage()
                .stage_merge_insert(
                    current_ds.clone(),
                    combined,
                    vec!["id".to_string()],
                    lance::dataset::WhenMatched::UpdateAll,
                    lance::dataset::WhenNotMatched::InsertAll,
                )
                .await?;
            current_ds = target_db
                .storage()
                .commit_staged(current_ds, staged_merge)
                .await?;
        }
    }
    crate::failpoints::maybe_fail(crate::failpoints::names::BRANCH_MERGE_ADOPT_AFTER_UPSERT_PRE_DELETE)?;
    if !delta.deleted_ids.is_empty() {
        // Single quotes are doubled so ids are safe inside the SQL string literal.
        let escaped: Vec<String> = delta
            .deleted_ids
            .iter()
            .map(|id| format!("'{}'", id.replace('\'', "''")))
            .collect();
        let filter = format!("id IN ({})", escaped.join(", "));
        if let Some(staged_delete) = target_db.storage().stage_delete(&current_ds, &filter).await? {
            current_ds = target_db
                .storage()
                .commit_staged(current_ds, staged_delete)
                .await?;
        }
    }
    let final_state = target_db
        .storage()
        .table_state(&full_path, &current_ds)
        .await?;
    Ok(crate::db::SubTableUpdate {
        table_key: table_key.to_string(),
        table_version: final_state.version,
        table_branch,
        row_count: final_state.row_count,
        version_metadata: final_state.version_metadata,
    })
}
impl Omnigraph {
    /// Merges branch `source` into branch `target` with no acting principal.
    ///
    /// Equivalent to `branch_merge_as(source, target, None)`.
    pub async fn branch_merge(&self, source: &str, target: &str) -> Result<MergeOutcome> {
        self.branch_merge_as(source, target, None).await
    }
    /// Merges branch `source` into branch `target` on behalf of `actor_id`.
    ///
    /// Checks the `BranchMerge` policy for the source→target transition, calls
    /// `ensure_schema_apply_idle` and `heal_pending_recovery_sidecars`, and
    /// then performs the merge.
    ///
    /// # Errors
    /// Policy denial, schema-apply or recovery errors, and every error from
    /// the merge itself (including `OmniError::MergeConflicts`).
    pub async fn branch_merge_as(
        &self,
        source: &str,
        target: &str,
        actor_id: Option<&str>,
    ) -> Result<MergeOutcome> {
        self.enforce(
            omnigraph_policy::PolicyAction::BranchMerge,
            &omnigraph_policy::ResourceScope::BranchTransition {
                source: source.to_string(),
                target: target.to_string(),
            },
            actor_id,
        )?;
        self.ensure_schema_apply_idle("branch_merge").await?;
        self.heal_pending_recovery_sidecars().await?;
        self.branch_merge_impl(source, target, actor_id).await
    }
    /// Resolves both branch heads and their merge base, then runs the merge
    /// with the coordinator temporarily switched to the target branch.
    ///
    /// Returns `AlreadyUpToDate` when the heads are equal or the source head
    /// is the merge base. The merge is a fast-forward when the base equals
    /// the target head.
    async fn branch_merge_impl(
        &self,
        source: &str,
        target: &str,
        actor_id: Option<&str>,
    ) -> Result<MergeOutcome> {
        if is_internal_system_branch(source) || is_internal_system_branch(target) {
            return Err(OmniError::manifest(format!(
                "branch_merge does not allow internal system refs ('{}' -> '{}')",
                source, target
            )));
        }
        let source_branch = Omnigraph::normalize_branch_name(source)?;
        let target_branch = Omnigraph::normalize_branch_name(target)?;
        if source_branch == target_branch {
            return Err(OmniError::manifest(
                "branch_merge requires distinct source and target branches".to_string(),
            ));
        }
        let source_head_commit_id = self
            .head_commit_id_for_branch(source_branch.as_deref())
            .await?
            .ok_or_else(|| OmniError::manifest("source branch has no head commit".to_string()))?;
        let target_head_commit_id = self
            .head_commit_id_for_branch(target_branch.as_deref())
            .await?
            .ok_or_else(|| OmniError::manifest("target branch has no head commit".to_string()))?;
        let base_commit = CommitGraph::merge_base(
            self.uri(),
            source_branch.as_deref(),
            target_branch.as_deref(),
        )
        .await?
        .ok_or_else(|| OmniError::manifest("branches have no common ancestor".to_string()))?;
        // Source head is already an ancestor of (or equal to) the target head.
        if source_head_commit_id == target_head_commit_id
            || base_commit.graph_commit_id == source_head_commit_id
        {
            return Ok(MergeOutcome::AlreadyUpToDate);
        }
        let is_fast_forward = base_commit.graph_commit_id == target_head_commit_id;
        let base_snapshot = ManifestCoordinator::snapshot_at(
            self.uri(),
            base_commit.manifest_branch.as_deref(),
            base_commit.manifest_version,
        )
        .await?;
        let source_snapshot = self
            .resolved_target(ReadTarget::Branch(
                source_branch.clone().unwrap_or_else(|| "main".to_string()),
            ))
            .await?
            .snapshot;
        // Serialize merges; the guard is held until this function returns.
        let merge_exclusive = self.merge_exclusive();
        let _merge_guard = merge_exclusive.lock().await;
        let previous_branch = self.active_branch().await;
        let previous = self
            .swap_coordinator_for_branch(target_branch.as_deref())
            .await?;
        let merge_result = self
            .branch_merge_on_current_target(
                &base_snapshot,
                &source_snapshot,
                &target_head_commit_id,
                &source_head_commit_id,
                is_fast_forward,
                actor_id,
            )
            .await;
        // Restore the coordinator whether or not the merge succeeded.
        self.restore_coordinator(previous).await;
        // When the caller was on the target branch, refresh its coordinator.
        // A refresh failure is surfaced only if the merge itself succeeded.
        if previous_branch == target_branch {
            if let Err(refresh_err) = self.refresh_coordinator_only().await {
                if merge_result.is_ok() {
                    return Err(refresh_err);
                }
                tracing::warn!(
                    error = %refresh_err,
                    "post-merge coordinator refresh failed on the error path; \
the next op or open will re-sync"
                );
            }
        }
        merge_result
    }
    /// Performs the merge against whichever branch the coordinator is on now.
    ///
    /// Phases:
    /// 1. classify every table key seen in base/source/target (skip, adopt,
    ///    or three-way merge) and abort on row conflicts;
    /// 2. validate constraints over the resulting changeset;
    /// 3. take the write queue for all candidate tables and verify their
    ///    target versions have not moved;
    /// 4. write a recovery sidecar for tables whose data is rewritten;
    /// 5. publish each table, confirm the sidecar, commit the merge commit,
    ///    then delete the sidecar (best effort).
    async fn branch_merge_on_current_target(
        &self,
        base_snapshot: &Snapshot,
        source_snapshot: &Snapshot,
        target_head_commit_id: &str,
        source_head_commit_id: &str,
        is_fast_forward: bool,
        actor_id: Option<&str>,
    ) -> Result<MergeOutcome> {
        let target_snapshot = self.snapshot().await;
        let mut table_keys = HashSet::new();
        for entry in base_snapshot.entries() {
            table_keys.insert(entry.table_key.clone());
        }
        for entry in source_snapshot.entries() {
            table_keys.insert(entry.table_key.clone());
        }
        for entry in target_snapshot.entries() {
            table_keys.insert(entry.table_key.clone());
        }
        // Sorted so classification and publishing run in a deterministic order.
        let mut ordered_table_keys: Vec<String> = table_keys.into_iter().collect();
        ordered_table_keys.sort();
        let mut conflicts = Vec::new();
        let mut candidates: HashMap<String, CandidateTableState> = HashMap::new();
        for table_key in &ordered_table_keys {
            let base_entry = base_snapshot.entry(table_key);
            let source_entry = source_snapshot.entry(table_key);
            let target_entry = target_snapshot.entry(table_key);
            // Both sides already agree: nothing to do.
            if same_manifest_state(source_entry, target_entry) {
                continue;
            }
            // Source did not touch this table: keep the target as is.
            if same_manifest_state(base_entry, source_entry) {
                continue;
            }
            // Only the source changed this table: adopt its state.
            if same_manifest_state(base_entry, target_entry) {
                let candidate = classify_adopt(
                    self,
                    &self.catalog(),
                    base_snapshot,
                    source_snapshot,
                    &target_snapshot,
                    table_key,
                )
                .await?;
                candidates.insert(table_key.clone(), candidate);
                continue;
            }
            // Both sides changed this table: row-level three-way merge.
            if let Some(staged) = stage_streaming_table_merge(
                table_key,
                &self.catalog(),
                base_snapshot,
                source_snapshot,
                &target_snapshot,
                &mut conflicts,
            )
            .await?
            {
                candidates.insert(
                    table_key.clone(),
                    CandidateTableState::RewriteMerged(staged),
                );
            }
        }
        if !conflicts.is_empty() {
            return Err(OmniError::MergeConflicts(conflicts));
        }
        let changeset = build_merge_changeset(self, &candidates).await?;
        validate_merge_candidates(self, &target_snapshot, &changeset).await?;
        let active_branch_for_keys = self.active_branch().await;
        // NOTE(review): this `matches!` lists every CandidateTableState variant,
        // so the filter only drops keys that have no candidate.
        let merge_queue_keys: Vec<(String, Option<String>)> = ordered_table_keys
            .iter()
            .filter(|table_key| {
                matches!(
                    candidates.get(*table_key),
                    Some(CandidateTableState::RewriteMerged(_))
                        | Some(CandidateTableState::AdoptSourceState { .. })
                        | Some(CandidateTableState::AdoptWithDelta(_))
                )
            })
            .map(|table_key| (table_key.clone(), active_branch_for_keys.clone()))
            .collect();
        let _merge_queue_guards = self.write_queue().acquire_many(&merge_queue_keys).await;
        // Re-read after acquiring the queue: any version moved by a concurrent
        // writer since `target_snapshot` was taken aborts the merge.
        let post_queue_snapshot = self.snapshot().await;
        for table_key in &ordered_table_keys {
            let Some(candidate) = candidates.get(table_key) else {
                continue;
            };
            if !matches!(
                candidate,
                CandidateTableState::RewriteMerged(_)
                    | CandidateTableState::AdoptSourceState { .. }
                    | CandidateTableState::AdoptWithDelta(_)
            ) {
                continue;
            }
            let expected = target_snapshot.entry(table_key).map(|e| e.table_version);
            let current = post_queue_snapshot
                .entry(table_key)
                .map(|e| e.table_version);
            if expected != current {
                return Err(OmniError::manifest_expected_version_mismatch(
                    table_key.clone(),
                    expected.unwrap_or(0),
                    current.unwrap_or(0),
                ));
            }
        }
        // Only tables whose data is rewritten in place get recovery pins, and
        // only if the target already has an entry for them.
        // NOTE(review): publishing may commit more than one version per table
        // (merge-insert, delete, index build); confirm recovery treats
        // `post_commit_pin` as a lower bound.
        let recovery_pins: Vec<crate::db::manifest::SidecarTablePin> = ordered_table_keys
            .iter()
            .filter_map(|table_key| {
                let candidate = candidates.get(table_key)?;
                if !matches!(
                    candidate,
                    CandidateTableState::RewriteMerged(_) | CandidateTableState::AdoptWithDelta(_)
                ) {
                    return None;
                }
                let entry = target_snapshot.entry(table_key)?;
                Some(crate::db::manifest::SidecarTablePin {
                    table_key: table_key.clone(),
                    table_path: self.storage().dataset_uri(&entry.table_path),
                    expected_version: entry.table_version,
                    post_commit_pin: entry.table_version + 1,
                    confirmed_version: None,
                    table_branch: active_branch_for_keys.clone(),
                })
            })
            .collect();
        let mut recovery: Option<(
            crate::db::manifest::RecoverySidecar,
            crate::db::manifest::RecoverySidecarHandle,
        )> = if recovery_pins.is_empty() {
            None
        } else {
            let target_branch = active_branch_for_keys.clone();
            let mut sidecar = crate::db::manifest::new_sidecar(
                crate::db::manifest::SidecarKind::BranchMerge,
                target_branch,
                actor_id.map(str::to_string),
                recovery_pins,
            );
            sidecar.merge_source_commit_id = Some(source_head_commit_id.to_string());
            let handle = crate::db::manifest::write_sidecar(
                self.root_uri(),
                self.storage_adapter(),
                &sidecar,
            )
            .await?;
            Some((sidecar, handle))
        };
        let mut updates = Vec::new();
        let mut changed_edge_tables = false;
        for table_key in &ordered_table_keys {
            let Some(candidate_state) = candidates.get(table_key) else {
                continue;
            };
            let update = match candidate_state {
                CandidateTableState::AdoptSourceState { .. } => {
                    publish_adopted_source_state(self, source_snapshot, &target_snapshot, table_key)
                        .await?
                }
                CandidateTableState::AdoptWithDelta(delta) => {
                    publish_adopted_delta(self, table_key, delta).await?
                }
                CandidateTableState::RewriteMerged(staged) => {
                    publish_rewritten_merge_table(self, table_key, staged).await?
                }
            };
            if table_key.starts_with("edge:") {
                changed_edge_tables = true;
            }
            updates.push(update);
        }
        // Record the versions actually published before committing the manifest.
        if let Some((sidecar, _)) = recovery.as_mut() {
            let confirmed_versions: std::collections::HashMap<String, u64> = updates
                .iter()
                .map(|u| (u.table_key.clone(), u.table_version))
                .collect();
            crate::db::manifest::confirm_sidecar_phase_b(
                self.root_uri(),
                self.storage_adapter(),
                sidecar,
                &confirmed_versions,
            )
            .await?;
        }
        crate::failpoints::maybe_fail(crate::failpoints::names::BRANCH_MERGE_POST_PHASE_B_PRE_MANIFEST_COMMIT)?;
        // The parameter is not used in this function; this binding only
        // silences the unused-variable lint.
        let _ = target_head_commit_id;
        self.commit_merge_with_actor(&updates, source_head_commit_id, actor_id)
            .await?;
        // Sidecar cleanup is best effort once the merge commit has landed.
        if let Some((_, handle)) = recovery {
            if let Err(err) =
                crate::db::manifest::delete_sidecar(&handle, self.storage_adapter()).await
            {
                tracing::warn!(
                    error = %err,
                    operation_id = handle.operation_id.as_str(),
                    "recovery sidecar cleanup failed; the next open's recovery sweep will resolve it"
                );
            }
        }
        if changed_edge_tables {
            self.invalidate_graph_index().await;
        }
        Ok(if is_fast_forward {
            MergeOutcome::FastForward
        } else {
            MergeOutcome::Merged
        })
    }
}