use crate::{
db::commit::{
marker::CommitMarker,
store::{CommitStore, with_commit_store, with_commit_store_infallible},
},
error::InternalError,
};
use std::panic::{AssertUnwindSafe, catch_unwind};
pub(crate) struct CommitApplyGuard {
phase: &'static str,
finished: bool,
rollbacks: Vec<Box<dyn FnOnce()>>,
}
impl CommitApplyGuard {
pub(crate) const fn new(phase: &'static str) -> Self {
Self {
phase,
finished: false,
rollbacks: Vec::new(),
}
}
pub(crate) fn record_rollback(&mut self, rollback: impl FnOnce() + 'static) {
self.rollbacks.push(Box::new(rollback));
}
pub(crate) fn finish(mut self) -> Result<(), InternalError> {
if self.finished {
return Err(InternalError::executor_invariant(format!(
"commit apply guard invariant violated: finish called twice ({})",
self.phase
)));
}
self.finished = true;
self.rollbacks.clear();
Ok(())
}
fn rollback_best_effort(&mut self) {
if self.finished {
return;
}
while let Some(rollback) = self.rollbacks.pop() {
let _ = catch_unwind(AssertUnwindSafe(rollback));
}
}
}
impl Drop for CommitApplyGuard {
fn drop(&mut self) {
if !self.finished {
self.rollback_best_effort();
}
}
}
#[derive(Clone, Debug)]
pub(crate) struct CommitGuard {
pub(crate) marker: CommitMarker,
}
impl CommitGuard {
fn clear(self) {
let _ = self;
with_commit_store_infallible(CommitStore::clear_infallible);
}
}
pub(crate) fn begin_commit(marker: CommitMarker) -> Result<CommitGuard, InternalError> {
with_commit_store(|store| {
if store.load()?.is_some() {
return Err(InternalError::store_invariant(
"commit marker already present before begin",
));
}
store.set(&marker)?;
Ok(CommitGuard { marker })
})
}
pub(crate) fn begin_commit_with_migration_state(
marker: CommitMarker,
migration_state_bytes: Vec<u8>,
) -> Result<CommitGuard, InternalError> {
with_commit_store(|store| {
if store.load()?.is_some() {
return Err(InternalError::store_invariant(
"commit marker already present before begin",
));
}
store.set_with_migration_state(&marker, migration_state_bytes)?;
Ok(CommitGuard { marker })
})
}
pub(crate) fn finish_commit(
mut guard: CommitGuard,
apply: impl FnOnce(&mut CommitGuard) -> Result<(), InternalError>,
) -> Result<(), InternalError> {
let result = apply(&mut guard);
let commit_id = guard.marker.id;
if result.is_ok() {
guard.clear();
assert!(
with_commit_store_infallible(|store| store.is_empty()),
"commit marker must be cleared after successful finish_commit (commit_id={commit_id:?})"
);
} else {
assert!(
with_commit_store_infallible(|store| !store.is_empty()),
"commit marker must remain persisted after failed finish_commit (commit_id={commit_id:?})"
);
}
result
}