use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use arrow_array::{Array, RecordBatch, StringArray, UInt32Array};
use arrow_schema::SchemaRef;
use lance::Dataset;
use omnigraph_compiler::catalog::EdgeType;
use crate::db::{MutationOpKind, SubTableUpdate};
use crate::db::manifest::{
new_sidecar, write_sidecar, RecoverySidecarHandle, SidecarKind, SidecarTablePin,
};
use crate::error::{OmniError, Result};
/// How buffered batches for a single table are applied at stage time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PendingMode {
    /// Plain append of new rows.
    Append,
    /// Merge-insert keyed on 'id': update matched rows, insert the rest.
    Merge,
}
/// Accumulated, not-yet-staged write batches for one sub-table.
#[derive(Debug)]
pub(crate) struct PendingTable {
    // Schema recorded when the table was first seen; later batches must be
    // column-compatible with it (see `schemas_compatible`).
    pub(crate) schema: SchemaRef,
    // Append vs merge semantics; merge is "sticky" once requested.
    pub(crate) mode: PendingMode,
    // Buffered record batches, in append order.
    pub(crate) batches: Vec<RecordBatch>,
}
impl PendingTable {
fn new(schema: SchemaRef, mode: PendingMode) -> Self {
Self {
schema,
mode,
batches: Vec::new(),
}
}
fn total_rows(&self) -> usize {
self.batches.iter().map(|b| b.num_rows()).sum()
}
}
/// Resolved storage location of a table being mutated.
#[derive(Debug, Clone)]
pub(crate) struct StagedTablePath {
    // Full URI/path of the table's dataset.
    pub(crate) full_path: String,
    // Branch the table is pinned to, when branched.
    pub(crate) table_branch: Option<String>,
}
/// Per-mutation scratch state keyed by table key: buffers write batches and
/// bookkeeping until `stage_all` turns them into staged (uncommitted) writes.
#[derive(Default)]
pub(crate) struct MutationStaging {
    // Manifest version each table is expected to be at when staged.
    pub(crate) expected_versions: HashMap<String, u64>,
    // Resolved storage path per table key.
    pub(crate) paths: HashMap<String, StagedTablePath>,
    // Buffered write batches per table key.
    pub(crate) pending: HashMap<String, PendingTable>,
    // Updates already committed inline (outside the staged two-phase path);
    // passed through at commit time.
    pub(crate) inline_committed: HashMap<String, SubTableUpdate>,
    // Operation kind per table key; drives strict version checking at commit.
    pub(crate) op_kinds: HashMap<String, MutationOpKind>,
}
impl MutationStaging {
    /// Registers path/version/op-kind bookkeeping for `table_key`.
    ///
    /// Idempotent: the first registration wins for the path and expected
    /// version. The op kind is only replaced when a later op requires the
    /// strict pre-stage version check and the recorded one does not.
    pub(crate) fn ensure_path(
        &mut self,
        table_key: &str,
        full_path: String,
        table_branch: Option<String>,
        expected_version: u64,
        op_kind: MutationOpKind,
    ) {
        self.paths.entry(table_key.to_string()).or_insert(StagedTablePath {
            full_path,
            table_branch,
        });
        self.expected_versions
            .entry(table_key.to_string())
            .or_insert(expected_version);
        // Upgrade-only: once any op on this table demands the strict check,
        // keep the stricter kind for the whole staging run.
        self.op_kinds
            .entry(table_key.to_string())
            .and_modify(|existing| {
                if op_kind.strict_pre_stage_version_check()
                    && !existing.strict_pre_stage_version_check()
                {
                    *existing = op_kind;
                }
            })
            .or_insert(op_kind);
    }

    /// Buffers `batch` for `table_key`. Empty batches are dropped. All
    /// batches for one table must agree on column names/types; a mismatch is
    /// a manifest error. One merge-mode batch promotes the whole table to
    /// merge semantics.
    pub(crate) fn append_batch(
        &mut self,
        table_key: &str,
        schema: SchemaRef,
        mode: PendingMode,
        batch: RecordBatch,
    ) -> Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }
        if let Some(existing) = self.pending.get(table_key) {
            if !schemas_compatible(&existing.schema, &batch.schema()) {
                return Err(OmniError::manifest(format!(
                    "table '{}' accumulated mutation batches with mismatched schemas: \
                    prior batches have {} columns, this batch has {}. \
                    This typically happens on a blob-bearing table when one \
                    op uses the full schema (e.g. an `insert`) and another \
                    omits unassigned blob columns (e.g. an `update` that \
                    doesn't set every blob property). Split the mutation \
                    into two queries: one for the inserts, one for the \
                    updates.",
                    table_key,
                    existing.schema.fields().len(),
                    batch.schema().fields().len(),
                )));
            }
        }
        let entry = self
            .pending
            .entry(table_key.to_string())
            .or_insert_with(|| PendingTable::new(schema.clone(), mode));
        // Merge is sticky: a single merge batch makes the table merge-mode.
        if mode == PendingMode::Merge {
            entry.mode = PendingMode::Merge;
        }
        entry.batches.push(batch);
        Ok(())
    }

    /// Records a sub-table update that was already committed inline, keyed by
    /// its table key (a later record for the same key replaces the earlier).
    pub(crate) fn record_inline(&mut self, update: SubTableUpdate) {
        self.inline_committed.insert(update.table_key.clone(), update);
    }

    /// All batches buffered for `table_key` (empty slice when none).
    pub(crate) fn pending_batches(&self, table_key: &str) -> &[RecordBatch] {
        self.pending
            .get(table_key)
            .map(|p| p.batches.as_slice())
            .unwrap_or(&[])
    }

    /// Schema recorded for `table_key`'s pending batches, if any.
    pub(crate) fn pending_schema(&self, table_key: &str) -> Option<SchemaRef> {
        self.pending.get(table_key).map(|p| p.schema.clone())
    }

    /// True when there is nothing to stage and nothing was inline-committed.
    pub(crate) fn is_empty(&self) -> bool {
        self.pending.is_empty() && self.inline_committed.is_empty()
    }

    /// Total buffered rows across all pending tables.
    #[allow(dead_code)]
    pub(crate) fn pending_row_count(&self) -> usize {
        self.pending.values().map(|p| p.total_rows()).sum()
    }

    /// Consumes the staging buffers: reopens each pending table's dataset at
    /// its expected version and stages (but does not commit) the writes.
    /// Merge-mode batches are deduplicated on 'id' (last write wins) before
    /// staging; inline-committed updates are passed through untouched.
    pub(crate) async fn stage_all(
        self,
        db: &crate::db::Omnigraph,
        _branch: Option<&str>,
    ) -> Result<StagedMutation> {
        let MutationStaging {
            expected_versions,
            paths,
            pending,
            inline_committed,
            op_kinds,
        } = self;
        let mut staged_entries: Vec<StagedTableEntry> = Vec::with_capacity(pending.len());
        for (table_key, table) in pending {
            // Path and expected version must have been registered via
            // `ensure_path`; a missing entry indicates an internal bug.
            let path = paths.get(&table_key).cloned().ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "MutationStaging::stage_all: missing path for table '{}'",
                    table_key
                ))
            })?;
            let expected = *expected_versions.get(&table_key).ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "MutationStaging::stage_all: missing expected version for table '{}'",
                    table_key
                ))
            })?;
            let stage_kind = match table.mode {
                PendingMode::Append => crate::db::MutationOpKind::Insert,
                PendingMode::Merge => crate::db::MutationOpKind::Merge,
            };
            // NOTE(review): the dataset is reopened even when there are no
            // batches to stage (the empty check comes after) — presumably
            // for a version-validation side effect; confirm before reordering.
            let ds = db
                .reopen_for_mutation(
                    &table_key,
                    &path.full_path,
                    path.table_branch.as_deref(),
                    expected,
                    stage_kind,
                )
                .await?;
            if table.batches.is_empty() {
                continue;
            }
            // Merge batches get last-write-wins dedup on 'id'; append batches
            // are simply concatenated (single batch avoids the copy).
            let combined = match table.mode {
                PendingMode::Merge => {
                    dedupe_merge_batches_by_id(&table.schema, table.batches)?
                }
                PendingMode::Append => {
                    if table.batches.len() == 1 {
                        table.batches.into_iter().next().unwrap()
                    } else {
                        arrow_select::concat::concat_batches(
                            &table.schema,
                            &table.batches,
                        )
                        .map_err(|e| OmniError::Lance(e.to_string()))?
                    }
                }
            };
            let staged = match table.mode {
                PendingMode::Append => {
                    db.table_store().stage_append(&ds, combined, &[]).await?
                }
                PendingMode::Merge => {
                    db.table_store()
                        .stage_merge_insert(
                            ds.clone(),
                            combined,
                            vec!["id".to_string()],
                            lance::dataset::WhenMatched::UpdateAll,
                            lance::dataset::WhenNotMatched::InsertAll,
                        )
                        .await?
                }
            };
            staged_entries.push(StagedTableEntry {
                table_key,
                path,
                expected_version: expected,
                dataset: ds,
                staged_write: staged,
            });
        }
        Ok(StagedMutation {
            inline_committed,
            staged: staged_entries,
            expected_versions,
            paths,
            op_kinds,
        })
    }
}
/// Output of `MutationStaging::stage_all`: everything needed to commit the
/// staged writes (or abandon them) in one place.
pub(crate) struct StagedMutation {
    // Updates that were already committed inline; passed through at commit.
    inline_committed: HashMap<String, SubTableUpdate>,
    // One entry per table with a staged (uncommitted) Lance transaction.
    staged: Vec<StagedTableEntry>,
    // Expected manifest version per table key (may be rebased at commit).
    expected_versions: HashMap<String, u64>,
    // Resolved storage path per table key.
    paths: HashMap<String, StagedTablePath>,
    // Operation kind per table key; drives strict version checks at commit.
    op_kinds: HashMap<String, MutationOpKind>,
}
/// A single table's staged-but-uncommitted write.
struct StagedTableEntry {
    table_key: String,
    path: StagedTablePath,
    // Version the table was reopened at; rebased to the current manifest
    // version at commit time for non-strict op kinds.
    expected_version: u64,
    // Dataset handle the staged transaction was built against.
    dataset: lance::Dataset,
    // The staged transaction awaiting commit.
    staged_write: crate::table_store::StagedWrite,
}
impl StagedMutation {
    /// Commits every staged write and assembles the final update set.
    ///
    /// Sequence:
    /// 1. acquire per-table write-queue guards (returned so the caller can
    ///    hold them through the subsequent manifest commit),
    /// 2. re-snapshot the manifest for `branch` and re-validate versions —
    ///    strict op kinds fail on any drift, non-strict ones rebase their
    ///    expected version onto the current one,
    /// 3. write a recovery sidecar pinning each table's pre-/post-commit
    ///    versions (skipped when there is nothing to pin),
    /// 4. commit each staged Lance transaction and build a `SubTableUpdate`
    ///    from the resulting dataset state.
    ///
    /// Returns (updates, rebased expected versions, sidecar handle, guards).
    pub(crate) async fn commit_all(
        self,
        db: &crate::db::Omnigraph,
        branch: Option<&str>,
        sidecar_kind: SidecarKind,
        actor_id: Option<&str>,
    ) -> Result<(
        Vec<SubTableUpdate>,
        HashMap<String, u64>,
        Option<RecoverySidecarHandle>,
        Vec<tokio::sync::OwnedMutexGuard<()>>,
    )> {
        let StagedMutation {
            inline_committed,
            mut staged,
            mut expected_versions,
            paths,
            op_kinds,
        } = self;
        // Every table we touch — staged or inline-committed — gets a write
        // queue slot so concurrent mutations on the same tables serialize.
        let mut queue_keys: Vec<(String, Option<String>)> = Vec::with_capacity(
            staged.len() + inline_committed.len(),
        );
        for entry in &staged {
            queue_keys.push((entry.table_key.clone(), entry.path.table_branch.clone()));
        }
        for table_key in inline_committed.keys() {
            let path = paths.get(table_key).ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "StagedMutation::commit_all: missing path for inline-committed table '{}'",
                    table_key
                ))
            })?;
            queue_keys.push((table_key.clone(), path.table_branch.clone()));
        }
        let guards = db.write_queue().acquire_many(&queue_keys).await;
        // Re-read the manifest under the guards: versions may have moved
        // since staging.
        let snapshot = db.snapshot_for_branch(branch).await?;
        for entry in staged.iter_mut() {
            let current = snapshot
                .entry(&entry.table_key)
                .map(|e| e.table_version)
                .ok_or_else(|| {
                    OmniError::manifest_conflict(format!(
                        "table '{}' missing from manifest at commit time",
                        entry.table_key,
                    ))
                })?;
            let strict = op_kinds
                .get(&entry.table_key)
                .map(|k| k.strict_pre_stage_version_check())
                .unwrap_or(false);
            if strict && entry.expected_version != current {
                return Err(OmniError::manifest_expected_version_mismatch(
                    entry.table_key.clone(),
                    entry.expected_version,
                    current,
                ));
            }
            // Non-strict ops rebase onto whatever version is current now.
            entry.expected_version = current;
            expected_versions.insert(entry.table_key.clone(), current);
        }
        let mut pins: Vec<SidecarTablePin> = Vec::with_capacity(
            staged.len() + inline_committed.len(),
        );
        for entry in &staged {
            pins.push(SidecarTablePin {
                table_key: entry.table_key.clone(),
                table_path: entry.path.full_path.clone(),
                expected_version: entry.expected_version,
                // NOTE(review): assumes the staged commit bumps the table
                // version by exactly one — confirm against commit_staged.
                post_commit_pin: entry.expected_version + 1,
                table_branch: entry.path.table_branch.clone(),
            });
        }
        for (table_key, update) in &inline_committed {
            let path = paths.get(table_key).ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "StagedMutation::commit_all: missing path for inline-committed table '{}'",
                    table_key
                ))
            })?;
            let expected = *expected_versions.get(table_key).ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "StagedMutation::commit_all: missing expected version for inline-committed table '{}'",
                    table_key
                ))
            })?;
            pins.push(SidecarTablePin {
                table_key: table_key.clone(),
                table_path: path.full_path.clone(),
                expected_version: expected,
                // Inline commits already happened; pin their actual version.
                post_commit_pin: update.table_version,
                table_branch: path.table_branch.clone(),
            });
        }
        // Persist the recovery sidecar before committing anything, so a crash
        // mid-commit leaves enough information to reconcile.
        let sidecar_handle = if pins.is_empty() {
            None
        } else {
            let sidecar = new_sidecar(
                sidecar_kind,
                branch.map(|s| s.to_string()),
                actor_id.map(str::to_string),
                pins,
            );
            Some(write_sidecar(db.root_uri(), db.storage_adapter(), &sidecar).await?)
        };
        // Inline-committed tables always get a strict version check.
        // NOTE(review): this runs after the sidecar is written, so a mismatch
        // error leaves the sidecar on disk — presumably the recovery path
        // reconciles it; confirm.
        for (table_key, _update) in inline_committed.iter() {
            let current = snapshot
                .entry(table_key)
                .map(|e| e.table_version)
                .ok_or_else(|| {
                    OmniError::manifest_conflict(format!(
                        "table '{}' missing from manifest at commit time",
                        table_key,
                    ))
                })?;
            let expected = expected_versions.get(table_key).copied().ok_or_else(|| {
                OmniError::manifest_internal(format!(
                    "StagedMutation::commit_all: missing expected version for inline-committed table '{}'",
                    table_key
                ))
            })?;
            if expected != current {
                return Err(OmniError::manifest_expected_version_mismatch(
                    table_key.clone(),
                    expected,
                    current,
                ));
            }
            expected_versions.insert(table_key.clone(), current);
        }
        // Commit the staged transactions and report each table's new state.
        let mut updates: Vec<SubTableUpdate> = inline_committed.into_values().collect();
        for entry in staged {
            let StagedTableEntry {
                table_key,
                path,
                expected_version: _,
                dataset,
                staged_write,
            } = entry;
            let new_ds = db
                .table_store()
                .commit_staged(Arc::new(dataset), staged_write.transaction)
                .await?;
            let state = db
                .table_store()
                .table_state(&path.full_path, &new_ds)
                .await?;
            updates.push(SubTableUpdate {
                table_key,
                table_version: state.version,
                table_branch: path.table_branch.clone(),
                row_count: state.row_count,
                version_metadata: state.version_metadata,
            });
        }
        Ok((updates, expected_versions, sidecar_handle, guards))
    }
}
/// Positional schema compatibility: same column count, and every column pair
/// agrees on name and data type. Nullability and field metadata are ignored.
fn schemas_compatible(a: &SchemaRef, b: &SchemaRef) -> bool {
    a.fields().len() == b.fields().len()
        && a.fields()
            .iter()
            .zip(b.fields().iter())
            .all(|(fa, fb)| fa.name() == fb.name() && fa.data_type() == fb.data_type())
}
/// Combines merge-mode batches into one, deduplicating rows on the 'id'
/// column with last-write-wins semantics (a later batch / later row within a
/// batch supersedes an earlier one). Rows with a null 'id' are always kept.
///
/// Fast path: when no duplicates exist the input is returned / concatenated
/// as-is. Otherwise surviving row indices are materialized per batch via the
/// Arrow `take` kernel and the slices concatenated under `schema`.
fn dedupe_merge_batches_by_id(
    schema: &SchemaRef,
    batches: Vec<RecordBatch>,
) -> Result<RecordBatch> {
    if batches.is_empty() {
        return Err(OmniError::manifest_internal(
            "dedupe_merge_batches_by_id: batches is empty".to_string(),
        ));
    }
    let mut seen: HashSet<String> = HashSet::new();
    // keep[b] = surviving row indices of batch b, ascending after the loop.
    let mut keep: Vec<Vec<u32>> = vec![Vec::new(); batches.len()];
    let mut any_duplicates = false;
    // Walk batches and rows newest-first so the *last* occurrence of each id
    // is the one that wins.
    for (b_idx, batch) in batches.iter().enumerate().rev() {
        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| {
                OmniError::manifest_internal(
                    "dedupe_merge_batches_by_id: batch has no 'id' column".to_string(),
                )
            })?
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| {
                OmniError::manifest_internal(
                    "dedupe_merge_batches_by_id: 'id' column is not Utf8".to_string(),
                )
            })?;
        for r_idx in (0..batch.num_rows()).rev() {
            if !id_col.is_valid(r_idx) {
                // Null ids never collide; keep every one.
                keep[b_idx].push(r_idx as u32);
                continue;
            }
            let id = id_col.value(r_idx);
            if seen.insert(id.to_string()) {
                keep[b_idx].push(r_idx as u32);
            } else {
                any_duplicates = true;
            }
        }
        // Indices were gathered descending; restore ascending order for take.
        keep[b_idx].reverse();
    }
    if !any_duplicates {
        // Nothing dropped — skip the take/copy entirely.
        if batches.len() == 1 {
            return Ok(batches.into_iter().next().unwrap());
        }
        return arrow_select::concat::concat_batches(schema, &batches)
            .map_err(|e| OmniError::Lance(e.to_string()));
    }
    let mut sliced: Vec<RecordBatch> = Vec::with_capacity(batches.len());
    for (b_idx, idxs) in keep.into_iter().enumerate() {
        if idxs.is_empty() {
            continue;
        }
        let take_array = UInt32Array::from(idxs);
        let columns: Vec<Arc<dyn Array>> = batches[b_idx]
            .columns()
            .iter()
            .map(|col| arrow_select::take::take(col, &take_array, None))
            .collect::<std::result::Result<_, _>>()
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let new_batch = RecordBatch::try_new(batches[b_idx].schema(), columns)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        sliced.push(new_batch);
    }
    if sliced.is_empty() {
        // Should be unreachable: input was non-empty and the last occurrence
        // of every id always survives.
        return Err(OmniError::manifest_internal(
            "dedupe_merge_batches_by_id: all rows were dropped (unexpected)".to_string(),
        ));
    }
    if sliced.len() == 1 {
        return Ok(sliced.into_iter().next().unwrap());
    }
    arrow_select::concat::concat_batches(schema, &sliced)
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Counts, per source id, how many edges will exist after the pending
/// mutation applies: committed rows plus pending rows.
///
/// When `dedupe_key_column` is given (merge-style edge tables), committed
/// rows whose key also appears in a pending batch are skipped — the merge
/// will overwrite them — and pending rows are deduplicated on that key with
/// last-write-wins semantics.
pub(crate) async fn count_src_per_edge(
    db: &crate::db::Omnigraph,
    committed_ds: &Dataset,
    table_key: &str,
    staging: &MutationStaging,
    dedupe_key_column: Option<&str>,
) -> Result<HashMap<String, u32>> {
    let mut counts: HashMap<String, u32> = HashMap::new();
    let pending_batches = staging.pending_batches(table_key);
    // Collect the dedup keys present in pending batches; committed rows with
    // these keys are about to be replaced and must not be double-counted.
    let pending_keys: Option<HashSet<String>> = match dedupe_key_column {
        Some(col) if !pending_batches.is_empty() => {
            let mut set = HashSet::new();
            for batch in pending_batches {
                if let Some(arr) = batch
                    .column_by_name(col)
                    .and_then(|c| c.as_any().downcast_ref::<StringArray>())
                {
                    for i in 0..arr.len() {
                        if arr.is_valid(i) {
                            set.insert(arr.value(i).to_string());
                        }
                    }
                }
            }
            Some(set)
        }
        _ => None,
    };
    // Only project the key column when there are keys to filter against.
    let projection: Vec<&str> = match dedupe_key_column {
        Some(col) if pending_keys.as_ref().is_some_and(|s| !s.is_empty()) => vec!["src", col],
        _ => vec!["src"],
    };
    let committed = db
        .table_store()
        .scan(committed_ds, Some(&projection), None, None)
        .await?;
    for batch in &committed {
        let srcs = batch
            .column_by_name("src")
            .ok_or_else(|| OmniError::Lance("missing 'src' column on edge table".into()))?
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| OmniError::Lance("'src' column is not Utf8".into()))?;
        // Mirror of the projection condition above: only filter when keys exist.
        let key_arr = match (&pending_keys, dedupe_key_column) {
            (Some(set), Some(col)) if !set.is_empty() => batch
                .column_by_name(col)
                .and_then(|c| c.as_any().downcast_ref::<StringArray>()),
            _ => None,
        };
        for i in 0..srcs.len() {
            if !srcs.is_valid(i) {
                continue;
            }
            // Skip committed rows superseded by a pending row with the same key.
            if let (Some(arr), Some(set)) = (key_arr, pending_keys.as_ref()) {
                if arr.is_valid(i) && set.contains(arr.value(i)) {
                    continue;
                }
            }
            *counts.entry(srcs.value(i).to_string()).or_insert(0) += 1;
        }
    }
    // Fold in the pending rows themselves.
    match dedupe_key_column {
        Some(key_col) => count_pending_src_with_dedupe(pending_batches, key_col, &mut counts)?,
        None => count_pending_src_naive(pending_batches, &mut counts),
    }
    Ok(counts)
}
/// Tallies every non-null 'src' value across `pending_batches` into `counts`.
/// Batches without a Utf8 'src' column are silently skipped (best-effort).
fn count_pending_src_naive(
    pending_batches: &[RecordBatch],
    counts: &mut HashMap<String, u32>,
) {
    for batch in pending_batches {
        let srcs = match batch
            .column_by_name("src")
            .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        {
            Some(arr) => arr,
            None => continue,
        };
        for row in 0..srcs.len() {
            if !srcs.is_valid(row) {
                continue;
            }
            *counts.entry(srcs.value(row).to_string()).or_default() += 1;
        }
    }
}
/// Tallies 'src' values of pending rows into `counts`, deduplicating rows on
/// `dedupe_key_column` with last-write-wins semantics: batches (and rows
/// within a batch) are walked newest-first so only the final write for each
/// key is counted. Rows with a null dedup key are never deduplicated.
fn count_pending_src_with_dedupe(
    pending_batches: &[RecordBatch],
    dedupe_key_column: &str,
    counts: &mut HashMap<String, u32>,
) -> Result<()> {
    let mut seen_keys: HashSet<String> = HashSet::new();
    for batch in pending_batches.iter().rev() {
        // The dedup key column must exist on every pending batch; its absence
        // means the schema-compat gate was bypassed.
        let key_col = batch.column_by_name(dedupe_key_column).ok_or_else(|| {
            OmniError::manifest_internal(format!(
                "count_pending_src_with_dedupe: pending batch missing dedup key column '{}' \
                (schema-compat check at append_batch should have rejected this)",
                dedupe_key_column
            ))
        })?;
        let keys = key_col
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| {
                OmniError::Lance(format!(
                    "count_src_per_edge: pending '{}' column is not Utf8",
                    dedupe_key_column
                ))
            })?;
        // Batches without a Utf8 'src' column contribute nothing.
        let srcs = match batch
            .column_by_name("src")
            .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        {
            Some(arr) => arr,
            None => continue,
        };
        for row in (0..batch.num_rows()).rev() {
            if !srcs.is_valid(row) {
                continue;
            }
            // A null key is always counted; otherwise only the newest row
            // per key (the first one seen in this reverse walk) counts.
            let is_kept = if keys.is_valid(row) {
                seen_keys.insert(keys.value(row).to_string())
            } else {
                true
            };
            if is_kept {
                *counts.entry(srcs.value(row).to_string()).or_default() += 1;
            }
        }
    }
    Ok(())
}
/// Validates per-source edge counts against the edge type's `@card` bounds,
/// returning a manifest error on the first violation found. For each source
/// the upper bound is checked before the lower one (preserving the original
/// error precedence). Sources absent from `counts` are not examined.
pub(crate) fn enforce_cardinality_bounds(
    edge_type: &EdgeType,
    counts: &HashMap<String, u32>,
) -> Result<()> {
    let bounds = &edge_type.cardinality;
    for (src, &n) in counts {
        // Upper bound first.
        match bounds.max {
            Some(max) if n > max => {
                return Err(OmniError::manifest(format!(
                    "@card violation on edge {}: source '{}' has {} edges (max {})",
                    edge_type.name, src, n, max
                )));
            }
            _ => {}
        }
        // Then the lower bound.
        if n < bounds.min {
            return Err(OmniError::manifest(format!(
                "@card violation on edge {}: source '{}' has {} edges (min {})",
                edge_type.name, src, n, bounds.min
            )));
        }
    }
    Ok(())
}