use entidb_core::{Database, EntityId};
use entidb_storage::StorageBackend;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Points in the database write path where a simulated crash can be injected.
///
/// NOTE(review): none of these variants are consumed in this file —
/// `CrashableBackend` below triggers crashes via a byte threshold and a
/// flush/sync flag instead. Presumably callers elsewhere map variants to
/// backend configuration; confirm before removing any.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrashPoint {
    /// Crash before any WAL bytes have been written.
    BeforeWalWrite,
    /// Crash part-way through writing a WAL record (torn write).
    DuringWalWrite,
    /// WAL record fully written, but the commit marker is not.
    AfterWalWriteBeforeCommit,
    /// Commit marker written, but not yet flushed to durable storage.
    AfterCommitBeforeFlush,
    /// Commit fully durable; crash happens afterwards.
    AfterCommitAndFlush,
    /// Crash while segment compaction is in progress.
    DuringCompaction,
    /// Crash while a checkpoint is in progress.
    DuringCheckpoint,
}
/// Outcome of a single crash-recovery scenario run by [`CrashRecoveryHarness`].
#[derive(Debug, Clone)]
pub struct CrashRecoveryResult {
    /// Whether the scenario behaved as expected.
    pub passed: bool,
    /// Human-readable name of the scenario.
    pub description: String,
    /// Number of entities the scenario expected to find after recovery.
    pub expected_entities: usize,
    /// Number of entities actually recovered with the correct payload.
    pub actual_entities: usize,
    /// Error message when the scenario failed; `None` on success.
    pub error: Option<String>,
}
impl CrashRecoveryResult {
pub fn pass(description: &str, entities: usize) -> Self {
Self {
passed: true,
description: description.to_string(),
expected_entities: entities,
actual_entities: entities,
error: None,
}
}
pub fn fail(description: &str, expected: usize, actual: usize, error: &str) -> Self {
Self {
passed: false,
description: description.to_string(),
expected_entities: expected,
actual_entities: actual,
error: Some(error.to_string()),
}
}
}
/// Storage backend wrapper that injects simulated crashes into the write path
/// (append/flush/sync) while delegating everything else to the inner backend.
pub struct CrashableBackend {
    // Real backend all operations are forwarded to.
    inner: Box<dyn StorageBackend>,
    // Byte offset after which appends start failing; `usize::MAX` disables it.
    crash_after_bytes: AtomicUsize,
    // Running total of bytes passed to `append` — advanced even when a
    // simulated crash rejects or tears the write (see the `append` impl).
    bytes_written: AtomicUsize,
    // Latched to true once any simulated crash has fired.
    crashed: AtomicBool,
    // When true, `flush` and `sync` fail with a simulated crash.
    fail_on_flush: AtomicBool,
}
impl CrashableBackend {
pub fn new(inner: Box<dyn StorageBackend>) -> Self {
Self {
inner,
crash_after_bytes: AtomicUsize::new(usize::MAX),
bytes_written: AtomicUsize::new(0),
crashed: AtomicBool::new(false),
fail_on_flush: AtomicBool::new(false),
}
}
pub fn crash_after(&self, bytes: usize) {
self.crash_after_bytes.store(bytes, Ordering::SeqCst);
}
pub fn set_fail_on_flush(&self, fail: bool) {
self.fail_on_flush.store(fail, Ordering::SeqCst);
}
pub fn reset(&self) {
self.crash_after_bytes.store(usize::MAX, Ordering::SeqCst);
self.bytes_written.store(0, Ordering::SeqCst);
self.crashed.store(false, Ordering::SeqCst);
self.fail_on_flush.store(false, Ordering::SeqCst);
}
pub fn has_crashed(&self) -> bool {
self.crashed.load(Ordering::SeqCst)
}
}
impl StorageBackend for CrashableBackend {
    // Reads pass straight through; crashes are only injected on the write path.
    fn read_at(&self, offset: u64, len: usize) -> entidb_storage::StorageResult<Vec<u8>> {
        self.inner.read_at(offset, len)
    }

    // Appends `bytes`, simulating a crash once the cumulative byte count
    // crosses the configured threshold. A write that straddles the threshold
    // is torn: only the bytes up to the threshold reach the inner backend.
    fn append(&mut self, bytes: &[u8]) -> entidb_storage::StorageResult<u64> {
        // Claim this write's range in the running counter. Note the counter
        // advances by the full length even when the write is rejected or
        // torn below — it tracks attempted bytes, not persisted bytes.
        let current = self.bytes_written.fetch_add(bytes.len(), Ordering::SeqCst);
        let crash_threshold = self.crash_after_bytes.load(Ordering::SeqCst);
        // Already at/past the threshold: reject the whole write.
        if current >= crash_threshold {
            self.crashed.store(true, Ordering::SeqCst);
            return Err(entidb_storage::StorageError::Io(std::io::Error::new(
                std::io::ErrorKind::Other,
                "simulated crash during write",
            )));
        }
        // Write straddles the threshold: persist only the leading fragment,
        // then fail — models a torn write interrupted by a crash.
        if current + bytes.len() > crash_threshold {
            self.crashed.store(true, Ordering::SeqCst);
            // Safe subtraction: current < crash_threshold in this branch.
            let partial_len = crash_threshold - current;
            if partial_len > 0 {
                // Inner append errors are deliberately ignored — the caller
                // is about to receive the simulated-crash error regardless.
                let _ = self.inner.append(&bytes[..partial_len]);
            }
            return Err(entidb_storage::StorageError::Io(std::io::Error::new(
                std::io::ErrorKind::Other,
                "simulated crash during partial write",
            )));
        }
        self.inner.append(bytes)
    }

    // Fails with a simulated crash when the flush-failure flag is set;
    // otherwise delegates.
    fn flush(&mut self) -> entidb_storage::StorageResult<()> {
        if self.fail_on_flush.load(Ordering::SeqCst) {
            self.crashed.store(true, Ordering::SeqCst);
            return Err(entidb_storage::StorageError::Io(std::io::Error::new(
                std::io::ErrorKind::Other,
                "simulated crash during flush",
            )));
        }
        self.inner.flush()
    }

    fn size(&self) -> entidb_storage::StorageResult<u64> {
        self.inner.size()
    }

    fn truncate(&mut self, new_size: u64) -> entidb_storage::StorageResult<()> {
        self.inner.truncate(new_size)
    }

    // Shares the flush-failure flag with `flush`: one switch fails both.
    fn sync(&mut self) -> entidb_storage::StorageResult<()> {
        if self.fail_on_flush.load(Ordering::SeqCst) {
            self.crashed.store(true, Ordering::SeqCst);
            return Err(entidb_storage::StorageError::Io(std::io::Error::new(
                std::io::ErrorKind::Other,
                "simulated crash during sync",
            )));
        }
        self.inner.sync()
    }
}
/// Drives a suite of crash-recovery scenarios against a database directory,
/// collecting a [`CrashRecoveryResult`] per scenario.
pub struct CrashRecoveryHarness {
    /// Directory holding the database under test.
    pub db_path: PathBuf,
    /// Results accumulated by the `test_*` methods, in execution order.
    pub results: Vec<CrashRecoveryResult>,
}
impl CrashRecoveryHarness {
    /// Creates a harness rooted at `db_path` with an empty result list.
    pub fn new(db_path: impl AsRef<Path>) -> Self {
        Self {
            db_path: db_path.as_ref().to_path_buf(),
            results: Vec::new(),
        }
    }

    /// Creates a harness under the OS temp directory, naming it with the
    /// process id plus a nanosecond timestamp so concurrent runs don't collide.
    pub fn with_temp_dir() -> std::io::Result<Self> {
        // Clock-before-epoch falls back to 0; uniqueness then rests on the
        // process id alone.
        let unique_id = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let temp_dir = std::env::temp_dir().join("entidb_crash_test").join(format!(
            "test_{}_{}",
            std::process::id(),
            unique_id
        ));
        std::fs::create_dir_all(&temp_dir)?;
        Ok(Self::new(temp_dir))
    }

    /// Removes the harness directory and everything under it, if it exists.
    pub fn cleanup(&self) -> std::io::Result<()> {
        if self.db_path.exists() {
            std::fs::remove_dir_all(&self.db_path)?;
        }
        Ok(())
    }

    /// Wipes the directory and opens a brand-new database in it.
    fn open_fresh_db(&self) -> Result<Database, entidb_core::CoreError> {
        // Best-effort removal: the directory may not exist on the first run.
        let _ = std::fs::remove_dir_all(&self.db_path);
        std::fs::create_dir_all(&self.db_path)?;
        Database::open(&self.db_path)
    }

    /// Reopens the database in place; this is where recovery paths
    /// (segment load, WAL replay) get exercised.
    fn reopen_db(&self) -> Result<Database, entidb_core::CoreError> {
        Database::open(&self.db_path)
    }

    /// Scenario: write and checkpoint 10 entities, close, reopen, and verify
    /// every entity comes back with the exact payload that was written.
    pub fn test_committed_data_survives(&mut self) -> CrashRecoveryResult {
        // The closure funnels every fallible step through `?`; any CoreError
        // is converted to a failing result below.
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let mut ids = Vec::new();
            // Each entity is committed in its own transaction.
            for i in 0..10 {
                let id = EntityId::new();
                ids.push(id);
                db.transaction(|txn| {
                    txn.put(collection, id, vec![i as u8; 100])?;
                    Ok(())
                })?;
            }
            db.checkpoint()?;
            db.close()?;
            // Explicit drop makes the close-then-release ordering obvious
            // before reopening the same path.
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let mut found = 0;
            // Count only entities whose payload matches byte-for-byte.
            for (i, id) in ids.iter().enumerate() {
                if let Some(data) = db.get(collection, *id)? {
                    if data == vec![i as u8; 100] {
                        found += 1;
                    }
                }
            }
            db.close()?;
            drop(db);
            if found == 10 {
                Ok(CrashRecoveryResult::pass(
                    "Committed data survives crash",
                    10,
                ))
            } else {
                Ok(CrashRecoveryResult::fail(
                    "Committed data survives crash",
                    10,
                    found,
                    "Some entities were lost",
                ))
            }
        })();
        // An error anywhere in the closure is itself a scenario failure.
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("Committed data survives crash", 10, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Scenario: verify a committed entity survives a close/reopen cycle.
    ///
    /// NOTE(review): despite the name, this scenario never creates any
    /// uncommitted data — it only checks that the committed entity is still
    /// present after reopen. To exercise the "discarded" half, a transaction
    /// would need to be aborted (or a crash injected) mid-flight before the
    /// reopen. Confirm whether the Database API supports that here.
    pub fn test_uncommitted_data_discarded(&mut self) -> CrashRecoveryResult {
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let committed_id = EntityId::new();
            db.transaction(|txn| {
                txn.put(collection, committed_id, b"committed".to_vec())?;
                Ok(())
            })?;
            db.checkpoint()?;
            db.close()?;
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let committed_exists = db.get(collection, committed_id)?.is_some();
            db.close()?;
            drop(db);
            if committed_exists {
                Ok(CrashRecoveryResult::pass(
                    "Uncommitted data discarded, committed data preserved",
                    1,
                ))
            } else {
                Ok(CrashRecoveryResult::fail(
                    "Uncommitted data discarded, committed data preserved",
                    1,
                    0,
                    "Committed data was lost",
                ))
            }
        })();
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("Uncommitted data discarded", 1, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Scenario: overwrite a single entity 5 times, compact, checkpoint, and
    /// verify the latest version (`vec![4u8; 50]`) survives a reopen.
    pub fn test_crash_after_compaction(&mut self) -> CrashRecoveryResult {
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let id = EntityId::new();
            // Five versions of the same entity; only the last should remain
            // visible after compaction.
            for i in 0..5 {
                db.transaction(|txn| {
                    txn.put(collection, id, vec![i as u8; 50])?;
                    Ok(())
                })?;
            }
            // NOTE(review): the meaning of compact's `false` argument is not
            // visible from this file — presumably a "full/forced" flag;
            // confirm against the Database API.
            let _stats = db.compact(false)?;
            db.checkpoint()?;
            db.close()?;
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let data = db.get(collection, id)?;
            db.close()?;
            drop(db);
            match data {
                // Latest payload written was vec![4u8; 50] (i == 4).
                Some(d) if d == vec![4u8; 50] => Ok(CrashRecoveryResult::pass(
                    "Latest version preserved after compaction",
                    1,
                )),
                // Entity present but holding a stale version.
                Some(_) => Ok(CrashRecoveryResult::fail(
                    "Latest version preserved after compaction",
                    1,
                    1,
                    "Wrong version of entity found",
                )),
                None => Ok(CrashRecoveryResult::fail(
                    "Latest version preserved after compaction",
                    1,
                    0,
                    "Entity lost after compaction",
                )),
            }
        })();
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("Crash after compaction", 1, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Scenario: commit 5 entities and close WITHOUT checkpointing, then
    /// verify all 5 are recovered on reopen.
    ///
    /// NOTE(review): this presumes `close` does not checkpoint, so reopen
    /// must replay the WAL to find these writes — confirm against the
    /// Database implementation.
    pub fn test_wal_replay(&mut self) -> CrashRecoveryResult {
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let mut ids = Vec::new();
            for i in 0..5 {
                let id = EntityId::new();
                ids.push(id);
                db.transaction(|txn| {
                    txn.put(collection, id, vec![i as u8; 100])?;
                    Ok(())
                })?;
            }
            // Deliberately no checkpoint here — contrast with
            // test_committed_data_survives.
            db.close()?;
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let mut found = 0;
            for (i, id) in ids.iter().enumerate() {
                if let Some(data) = db.get(collection, *id)? {
                    if data == vec![i as u8; 100] {
                        found += 1;
                    }
                }
            }
            db.close()?;
            drop(db);
            if found == 5 {
                Ok(CrashRecoveryResult::pass(
                    "WAL replay recovers committed data",
                    5,
                ))
            } else {
                Ok(CrashRecoveryResult::fail(
                    "WAL replay recovers committed data",
                    5,
                    found,
                    "Some entities were not recovered from WAL",
                ))
            }
        })();
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("WAL replay", 5, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Scenario: 3 entities committed before a checkpoint (expected to live
    /// in segments) plus 3 committed after it (expected to live in the WAL);
    /// all 6 must be recovered on reopen.
    pub fn test_mixed_recovery(&mut self) -> CrashRecoveryResult {
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let segment_ids: Vec<EntityId> = (0..3).map(|_| EntityId::new()).collect();
            for (i, id) in segment_ids.iter().enumerate() {
                db.transaction(|txn| {
                    txn.put(collection, *id, format!("segment_{}", i).into_bytes())?;
                    Ok(())
                })?;
            }
            // Checkpoint splits the two batches: everything above should now
            // be durable outside the WAL.
            db.checkpoint()?;
            let wal_ids: Vec<EntityId> = (0..3).map(|_| EntityId::new()).collect();
            for (i, id) in wal_ids.iter().enumerate() {
                db.transaction(|txn| {
                    txn.put(collection, *id, format!("wal_{}", i).into_bytes())?;
                    Ok(())
                })?;
            }
            db.close()?;
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let mut segment_found = 0;
            for (i, id) in segment_ids.iter().enumerate() {
                if let Some(data) = db.get(collection, *id)? {
                    if data == format!("segment_{}", i).into_bytes() {
                        segment_found += 1;
                    }
                }
            }
            let mut wal_found = 0;
            for (i, id) in wal_ids.iter().enumerate() {
                if let Some(data) = db.get(collection, *id)? {
                    if data == format!("wal_{}", i).into_bytes() {
                        wal_found += 1;
                    }
                }
            }
            db.close()?;
            drop(db);
            let total_found = segment_found + wal_found;
            if total_found == 6 {
                Ok(CrashRecoveryResult::pass(
                    "Mixed segment and WAL recovery",
                    6,
                ))
            } else {
                Ok(CrashRecoveryResult::fail(
                    "Mixed segment and WAL recovery",
                    6,
                    total_found,
                    &format!(
                        "Found {} from segments, {} from WAL",
                        segment_found, wal_found
                    ),
                ))
            }
        })();
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("Mixed recovery", 6, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Scenario: put then delete an entity (checkpointing after each), and
    /// verify the delete is still in effect after a reopen.
    pub fn test_delete_survives_crash(&mut self) -> CrashRecoveryResult {
        let result = (|| {
            let db = self.open_fresh_db()?;
            let collection = db.collection("test");
            let id = EntityId::new();
            db.transaction(|txn| {
                txn.put(collection, id, b"test data".to_vec())?;
                Ok(())
            })?;
            db.checkpoint()?;
            db.transaction(|txn| {
                txn.delete(collection, id)?;
                Ok(())
            })?;
            db.checkpoint()?;
            db.close()?;
            drop(db);
            let db = self.reopen_db()?;
            let collection = db.collection("test");
            let exists = db.get(collection, id)?.is_some();
            db.close()?;
            drop(db);
            // Success means the entity stayed deleted across the reopen.
            if !exists {
                Ok(CrashRecoveryResult::pass("Delete survives crash", 0))
            } else {
                Ok(CrashRecoveryResult::fail(
                    "Delete survives crash",
                    0,
                    1,
                    "Deleted entity still exists",
                ))
            }
        })();
        let result = result.unwrap_or_else(|e: entidb_core::CoreError| {
            CrashRecoveryResult::fail("Delete survives crash", 0, 0, &e.to_string())
        });
        self.results.push(result.clone());
        result
    }

    /// Runs every scenario in a fixed order, replacing any previous results,
    /// and returns a copy of the collected outcomes.
    pub fn run_all_tests(&mut self) -> Vec<CrashRecoveryResult> {
        self.results.clear();
        self.test_committed_data_survives();
        self.test_uncommitted_data_discarded();
        self.test_crash_after_compaction();
        self.test_wal_replay();
        self.test_mixed_recovery();
        self.test_delete_survives_crash();
        self.results.clone()
    }

    /// Renders a human-readable report of all collected results, including
    /// per-scenario status marks and any error messages.
    pub fn summary(&self) -> String {
        let passed = self.results.iter().filter(|r| r.passed).count();
        let total = self.results.len();
        let mut summary = format!(
            "\n=== Crash Recovery Test Summary ===\n\
             Passed: {}/{}\n\n",
            passed, total
        );
        for result in &self.results {
            let status = if result.passed { "✓" } else { "✗" };
            summary.push_str(&format!(
                "{} {}\n  Expected: {} entities, Actual: {} entities\n",
                status, result.description, result.expected_entities, result.actual_entities
            ));
            if let Some(ref error) = result.error {
                summary.push_str(&format!("  Error: {}\n", error));
            }
        }
        summary
    }

    /// True when every collected result passed (vacuously true if no
    /// scenarios have been run yet).
    pub fn all_passed(&self) -> bool {
        self.results.iter().all(|r| r.passed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use entidb_storage::InMemoryBackend;

    /// With no crash triggers armed, the wrapper must behave exactly like
    /// the inner backend for append/flush/read.
    #[test]
    fn test_crashable_backend_normal_operation() {
        let mut backend = CrashableBackend::new(Box::new(InMemoryBackend::new()));
        let payload = b"test data";
        let offset = backend.append(payload).unwrap();
        backend.flush().unwrap();
        assert_eq!(backend.read_at(offset, payload.len()).unwrap(), payload);
    }

    /// An append that crosses the configured byte threshold must fail and
    /// latch the crashed flag.
    #[test]
    fn test_crashable_backend_crash_on_write() {
        let mut backend = CrashableBackend::new(Box::new(InMemoryBackend::new()));
        backend.crash_after(10);
        // First write stays under the 10-byte threshold.
        let _ = backend.append(&[1u8; 5]);
        // Second write straddles it and must be rejected.
        assert!(backend.append(&[2u8; 10]).is_err());
        assert!(backend.has_crashed());
    }

    /// Flush failures fire while the fail-on-flush flag is set.
    #[test]
    fn test_crashable_backend_crash_on_flush() {
        let mut backend = CrashableBackend::new(Box::new(InMemoryBackend::new()));
        backend.set_fail_on_flush(true);
        assert!(backend.flush().is_err());
        assert!(backend.has_crashed());
    }

    /// Smoke test: a single harness scenario runs against a real temp dir.
    #[test]
    fn test_crash_recovery_harness() {
        let mut harness = CrashRecoveryHarness::with_temp_dir().unwrap();
        let result = harness.test_committed_data_survives();
        println!("{:?}", result);
        harness.cleanup().unwrap();
    }

    /// Runs the full scenario suite and requires every check to pass.
    #[test]
    fn test_all_crash_recovery_scenarios() {
        let mut harness = CrashRecoveryHarness::with_temp_dir().unwrap();
        let _results = harness.run_all_tests();
        println!("{}", harness.summary());
        harness.cleanup().unwrap();
        assert!(harness.all_passed(), "Some crash recovery tests failed");
    }
}