use crate::{
db::{
Db,
commit::{
memory::configure_commit_memory_id,
rebuild_secondary_indexes_from_rows, replay_commit_marker_row_ops,
store::{commit_marker_present_fast, with_commit_store},
},
diagnostics::integrity_report_after_recovery,
},
error::{ErrorOrigin, InternalError},
traits::CanisterKind,
};
use std::sync::OnceLock;
static RECOVERED: OnceLock<()> = OnceLock::new();
pub(crate) fn ensure_recovered<C: CanisterKind>(db: &Db<C>) -> Result<(), InternalError> {
configure_commit_memory_id(C::COMMIT_MEMORY_ID)
.map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
if RECOVERED.get().is_none() {
return perform_recovery(db);
}
if commit_marker_present_fast().map_err(|err| err.with_origin(ErrorOrigin::Recovery))? {
return perform_recovery(db);
}
Ok(())
}
fn perform_recovery<C: CanisterKind>(db: &Db<C>) -> Result<(), InternalError> {
let marker = with_commit_store(|store| store.load())
.map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
let had_marker = marker.is_some();
if let Some(marker) = marker {
replay_commit_marker_row_ops(db, &marker.row_ops)
.map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
}
rebuild_secondary_indexes_from_rows(db)
.map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
validate_recovery_integrity(db).map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
if had_marker {
with_commit_store(|store| {
store.clear_infallible();
Ok(())
})
.map_err(|err| err.with_origin(ErrorOrigin::Recovery))?;
}
db.mark_all_registered_index_stores_ready();
let _ = RECOVERED.set(());
Ok(())
}
fn validate_recovery_integrity<C: CanisterKind>(db: &Db<C>) -> Result<(), InternalError> {
if !db.has_runtime_hooks() {
return Ok(());
}
let report = integrity_report_after_recovery(db)?;
let totals = report.totals();
if totals.missing_index_entries() > 0
|| totals.divergent_index_entries() > 0
|| totals.orphan_index_references() > 0
{
return Err(InternalError::recovery_integrity_validation_failed(
totals.missing_index_entries(),
totals.divergent_index_entries(),
totals.orphan_index_references(),
));
}
Ok(())
}