use crate::{
db::commit::{
PreparedRowCommitOp,
marker::{COMMIT_ID_BYTES, CommitMarker, CommitRowOp, generate_commit_id},
store::{CommitStore, with_commit_store, with_commit_store_infallible},
},
error::InternalError,
};
use std::panic::{AssertUnwindSafe, catch_unwind};
enum ApplyRollback {
None,
Closure(Box<dyn FnOnce()>),
SinglePreparedRow(PreparedRowCommitOp),
}
pub(crate) struct CommitApplyGuard {
phase: &'static str,
finished: bool,
rollback: ApplyRollback,
}
impl CommitApplyGuard {
pub(crate) const fn new(phase: &'static str) -> Self {
Self {
phase,
finished: false,
rollback: ApplyRollback::None,
}
}
pub(crate) fn record_rollback(&mut self, rollback: impl FnOnce() + 'static) {
debug_assert!(
matches!(self.rollback, ApplyRollback::None),
"commit apply guard currently owns exactly one rollback closure",
);
if matches!(self.rollback, ApplyRollback::None) {
self.rollback = ApplyRollback::Closure(Box::new(rollback));
}
}
pub(crate) fn record_single_row_rollback(&mut self, rollback: PreparedRowCommitOp) {
debug_assert!(
matches!(self.rollback, ApplyRollback::None),
"commit apply guard currently owns exactly one rollback payload",
);
if matches!(self.rollback, ApplyRollback::None) {
self.rollback = ApplyRollback::SinglePreparedRow(rollback);
}
}
pub(crate) fn finish(mut self) -> Result<(), InternalError> {
if self.finished {
return Err(InternalError::executor_invariant(format!(
"commit apply guard invariant violated: finish called twice ({})",
self.phase
)));
}
self.finished = true;
self.rollback = ApplyRollback::None;
Ok(())
}
fn rollback_best_effort(&mut self) {
if self.finished {
return;
}
match std::mem::replace(&mut self.rollback, ApplyRollback::None) {
ApplyRollback::None => {}
ApplyRollback::Closure(rollback) => {
let _ = catch_unwind(AssertUnwindSafe(rollback));
}
ApplyRollback::SinglePreparedRow(rollback) => {
let _ = catch_unwind(AssertUnwindSafe(|| rollback.apply()));
}
}
}
}
impl Drop for CommitApplyGuard {
fn drop(&mut self) {
if !self.finished {
self.rollback_best_effort();
}
}
}
#[derive(Clone, Debug)]
pub(crate) struct CommitGuard {
commit_id: [u8; COMMIT_ID_BYTES],
}
impl CommitGuard {
const fn for_persisted_id(commit_id: [u8; COMMIT_ID_BYTES]) -> Self {
Self { commit_id }
}
fn clear(self) {
let _ = self;
with_commit_store_infallible(CommitStore::clear_infallible);
}
}
pub(crate) fn begin_commit(marker: CommitMarker) -> Result<CommitGuard, InternalError> {
with_commit_store(|store| {
let commit_id = marker.id;
store.set_if_empty(&marker)?;
Ok(CommitGuard::for_persisted_id(commit_id))
})
}
pub(crate) fn begin_single_row_commit(row_op: CommitRowOp) -> Result<CommitGuard, InternalError> {
with_commit_store(|store| {
let commit_id = generate_commit_id()?;
store.set_single_row_op_if_empty(commit_id, &row_op)?;
Ok(CommitGuard::for_persisted_id(commit_id))
})
}
pub(crate) fn begin_commit_with_migration_state(
marker: CommitMarker,
migration_state_bytes: Vec<u8>,
) -> Result<CommitGuard, InternalError> {
with_commit_store(|store| {
if !store.marker_is_empty()? {
return Err(InternalError::store_invariant(
"commit marker already present before begin",
));
}
let commit_id = marker.id;
store.set_with_migration_state(&marker, migration_state_bytes)?;
Ok(CommitGuard::for_persisted_id(commit_id))
})
}
pub(crate) fn finish_commit(
mut guard: CommitGuard,
apply: impl FnOnce(&mut CommitGuard) -> Result<(), InternalError>,
) -> Result<(), InternalError> {
let result = apply(&mut guard);
let commit_id = guard.commit_id;
if result.is_ok() {
guard.clear();
assert!(
with_commit_store_infallible(|store| store.is_empty()),
"commit marker must be cleared after successful finish_commit (commit_id={commit_id:?})"
);
} else {
assert!(
with_commit_store_infallible(|store| !store.is_empty()),
"commit marker must remain persisted after failed finish_commit (commit_id={commit_id:?})"
);
}
result
}