use crate::raft::OxirsNodeId;
use crate::serialization::{MessageSerializer, SerializationConfig};
use crate::storage::{RaftState, SnapshotMetadata, WalEntry, WalOperation};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use lmdb::{Database, DatabaseFlags, Environment, Transaction, WriteFlags};
use memmap2::{Mmap, MmapOptions};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs;
use tokio::sync::{Mutex, RwLock};
use tokio::time::Instant;
use scirs2_core::profiling::Profiler;
/// Configuration for the advanced storage backend: directory layout, WAL
/// durability mode, caching, and optional compression/encryption.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Root directory holding the WAL, the LMDB environment, and snapshot files.
    pub data_dir: PathBuf,
    /// Maximum WAL file size in bytes before `maintenance()` requests truncation.
    pub max_wal_size: u64,
    /// How aggressively WAL writes are fsync'd (see [`WalSyncMode`]).
    pub wal_sync_mode: WalSyncMode,
    /// Whether fsync is enabled at all.
    /// NOTE(review): not consulted by the code visible here — `wal_sync_mode`
    /// alone drives syncing. Confirm intended use.
    pub enable_fsync: bool,
    /// Allow memory-mapped reads via `enable_mmap()`.
    pub enable_mmap: bool,
    /// Checkpoint interval in seconds.
    /// NOTE(review): not consulted by the code visible here — confirm.
    pub checkpoint_interval: u64,
    /// Whether `compact()`/`maintenance()` may run compaction.
    pub enable_compaction: bool,
    /// Enable LZ4 compression for serialized records and snapshot data.
    pub enable_compression: bool,
    /// Cache budget in bytes; also reused as the LMDB map size and to derive
    /// the LRU caches' entry counts.
    pub cache_size: usize,
    /// Enable at-rest encryption.
    /// NOTE(review): not consulted by the code visible here — confirm.
    pub enable_encryption: bool,
    /// 256-bit key, expected when `enable_encryption` is set.
    pub encryption_key: Option<[u8; 32]>,
}
impl Default for StorageConfig {
fn default() -> Self {
Self {
data_dir: PathBuf::from("./data"),
max_wal_size: 64 * 1024 * 1024, wal_sync_mode: WalSyncMode::Sync,
enable_fsync: true,
enable_mmap: true,
checkpoint_interval: 300, enable_compaction: true,
enable_compression: true,
cache_size: 128 * 1024 * 1024, enable_encryption: false,
encryption_key: None,
}
}
}
/// Durability level applied when writing to the write-ahead log and to
/// atomically-committed files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WalSyncMode {
    /// Never fsync; fastest, but recent writes may be lost on crash.
    NoSync,
    /// fsync at append/commit boundaries.
    Sync,
    /// fsync on every write, including intermediate `AtomicFile` writes.
    FullSync,
}
/// Aggregated runtime statistics for the storage backend.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StorageStats {
    /// Total store/load operations performed.
    pub total_operations: u64,
    /// Bytes written to LMDB and snapshot files.
    pub bytes_written: u64,
    /// Bytes read back from disk/LMDB.
    pub bytes_read: u64,
    /// Number of snapshots (checkpoints) created.
    pub checkpoints_created: u64,
    /// Times a checksum mismatch was detected on read.
    pub corruption_detections: u64,
    /// Number of WAL recovery passes performed.
    pub recovery_operations: u64,
    /// Running average write latency.
    pub avg_write_latency: Duration,
    /// Running average read latency.
    pub avg_read_latency: Duration,
    /// Exponentially-weighted cache hit ratio in [0, 1].
    pub cache_hit_ratio: f64,
    /// Current WAL file size in bytes (refreshed by `stats()`).
    pub current_wal_size: u64,
    /// Number of compaction runs recorded.
    pub compactions_performed: u64,
}
/// Write-to-temp-then-rename helper for crash-safe file replacement.
///
/// Data is written to `<final_path with .tmp extension>`; `commit()` renames
/// it into place so readers never observe a partially written file. If the
/// value is dropped without committing, the temp file is removed.
#[derive(Debug)]
struct AtomicFile {
    // Open handle to the temporary file being written.
    file: File,
    // Path of the temporary file (final path with a `.tmp` extension).
    temp_path: PathBuf,
    // Destination path the temp file is renamed to on commit.
    final_path: PathBuf,
    // Controls when `sync_all` is called during writes/commit.
    sync_mode: WalSyncMode,
}
impl AtomicFile {
    /// Creates (truncating) the temporary sibling of `final_path`.
    ///
    /// Note: `with_extension` *replaces* any existing extension, so
    /// `snapshot-1-1.dat` gets the temp path `snapshot-1-1.tmp`.
    ///
    /// # Errors
    /// Fails if the temp file cannot be created/opened.
    fn create(final_path: PathBuf, sync_mode: WalSyncMode) -> Result<Self> {
        let temp_path = final_path.with_extension("tmp");
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&temp_path)?;
        Ok(Self {
            file,
            temp_path,
            final_path,
            sync_mode,
        })
    }
    /// Appends `data` to the temp file, fsyncing after every write when in
    /// `FullSync` mode.
    fn write_all(&mut self, data: &[u8]) -> Result<()> {
        self.file.write_all(data)?;
        if matches!(self.sync_mode, WalSyncMode::FullSync) {
            self.file.sync_all()?;
        }
        Ok(())
    }
    /// Syncs (in `Sync`/`FullSync` modes) and atomically renames the temp
    /// file over `final_path`.
    ///
    /// `self` is consumed, so `Drop` still runs afterwards; its
    /// `remove_file` on the already-renamed temp path fails harmlessly and
    /// is ignored.
    fn commit(self) -> Result<()> {
        if matches!(self.sync_mode, WalSyncMode::Sync | WalSyncMode::FullSync) {
            self.file.sync_all()?;
        }
        std::fs::rename(&self.temp_path, &self.final_path)?;
        Ok(())
    }
}
impl Drop for AtomicFile {
    /// Best-effort cleanup of the temp file for writers that never commit.
    /// After a successful `commit()` the temp file has been renamed away,
    /// so the removal fails and the error is deliberately ignored.
    fn drop(&mut self) {
        let _ = std::fs::remove_file(&self.temp_path);
    }
}
/// Append-only write-ahead log with length-prefixed, checksummed entries
/// stored in `<data_dir>/wal.log`.
struct WriteAheadLog {
    // Storage settings (data dir, sync mode, compression).
    config: StorageConfig,
    // Next sequence number to hand out; advanced on every append and
    // restored by `recover()`.
    current_sequence: AtomicU64,
    // Shared handle to the open `wal.log` file (append mode).
    wal_file: Arc<Mutex<File>>,
    // Serializer used to encode/decode WAL entries.
    serializer: Arc<Mutex<MessageSerializer>>,
}
impl WriteAheadLog {
/// Opens (or creates) the WAL file under `config.data_dir` and prepares the
/// serializer used for entry encoding.
///
/// The in-memory sequence counter starts at 0; it is advanced to the correct
/// value by `recover()`, which the storage backend invokes on startup.
/// (A previous dead, empty `if wal_path.exists() && ...` block has been
/// removed — it performed no work.)
///
/// # Errors
/// Returns an error if the data directory or WAL file cannot be created.
async fn new(config: StorageConfig) -> Result<Self> {
    let wal_path = config.data_dir.join("wal.log");
    if let Some(parent) = wal_path.parent() {
        fs::create_dir_all(parent).await?;
    }
    // Append mode so a restart never clobbers existing entries.
    let wal_file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&wal_path)?;
    let current_sequence = AtomicU64::new(0);
    let serializer_config = SerializationConfig {
        compression: if config.enable_compression {
            crate::serialization::CompressionAlgorithm::Lz4
        } else {
            crate::serialization::CompressionAlgorithm::None
        },
        ..Default::default()
    };
    let serializer = Arc::new(Mutex::new(MessageSerializer::with_config(
        serializer_config,
    )));
    Ok(Self {
        config,
        current_sequence,
        wal_file: Arc::new(Mutex::new(wal_file)),
        serializer,
    })
}
/// Appends an operation to the WAL and returns its sequence number.
///
/// The entry's checksum covers the serialization of the entry with an
/// *empty* checksum field; verification at recovery time must recompute it
/// the same way.
///
/// Fixes: the WAL file lock is now taken *before* the sequence number is
/// assigned, so concurrent appends reach the file in strictly increasing
/// sequence order (previously the sequence was drawn first, allowing
/// out-of-order frames). The serializer mutex is also locked once instead
/// of twice, and a redundant clone of `operation` was removed.
///
/// # Errors
/// Returns an error on serialization or I/O failure.
async fn append(&self, operation: WalOperation) -> Result<u64> {
    // Hold the file lock for the whole append so sequence assignment and the
    // physical write cannot interleave between tasks.
    let mut wal_file = self.wal_file.lock().await;
    let sequence = self.current_sequence.fetch_add(1, Ordering::SeqCst);
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("SystemTime should be after UNIX_EPOCH")
        .as_secs();
    let entry = WalEntry {
        sequence,
        timestamp,
        operation,
        checksum: String::new(), // placeholder; filled in below
    };
    // Lock the serializer once for both passes (checksum pass + final encode).
    let final_serialized = {
        let mut serializer = self.serializer.lock().await;
        let serialized = serializer.serialize(&entry)?;
        let checksum = format!("{:x}", Sha256::digest(&serialized.payload));
        let entry_with_checksum = WalEntry { checksum, ..entry };
        serializer.serialize(&entry_with_checksum)?
    };
    // Frame format: u32 little-endian payload size, then the payload bytes.
    let entry_size = (final_serialized.payload.len() as u32).to_le_bytes();
    wal_file.write_all(&entry_size)?;
    wal_file.write_all(&final_serialized.payload)?;
    match self.config.wal_sync_mode {
        WalSyncMode::NoSync => {}
        WalSyncMode::Sync | WalSyncMode::FullSync => {
            wal_file.sync_all()?;
        }
    }
    Ok(sequence)
}
/// Replays the WAL file, returning all entries that pass checksum
/// verification and advancing `current_sequence` past the highest recovered
/// sequence number.
///
/// Bug fix: `append` computes the stored checksum over the serialization of
/// the entry with an *empty* checksum field, but recovery previously hashed
/// the on-disk frame (which already embeds the checksum), so verification
/// could never succeed and recovery always stopped at the first entry. We
/// now re-serialize the decoded entry with a cleared checksum and hash
/// that, mirroring `append`.
///
/// Recovery stops (without erroring) at the first corrupt or undecodable
/// frame, treating the tail as a torn write.
///
/// # Errors
/// Propagates I/O errors other than a clean EOF at a frame boundary.
async fn recover(&self) -> Result<Vec<WalEntry>> {
    let wal_path = self.config.data_dir.join("wal.log");
    if !wal_path.exists() {
        return Ok(Vec::new());
    }
    let mut file = File::open(&wal_path)?;
    let mut recovered_entries = Vec::new();
    let mut buffer = Vec::new();
    let mut max_sequence: Option<u64> = None;
    let mut serializer = self.serializer.lock().await;
    loop {
        // Each frame: u32 LE length prefix followed by the payload.
        let mut size_bytes = [0u8; 4];
        match file.read_exact(&mut size_bytes) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e.into()),
        }
        let entry_size = u32::from_le_bytes(size_bytes) as usize;
        buffer.clear();
        buffer.resize(entry_size, 0);
        file.read_exact(&mut buffer)?;
        // NOTE(review): frames are decoded declaring CompressionAlgorithm::None
        // even when writes were configured with LZ4 — confirm the serializer
        // tolerates this asymmetry.
        let serialized_message = crate::serialization::SerializedMessage {
            schema_version: Default::default(),
            compression: crate::serialization::CompressionAlgorithm::None,
            format: crate::serialization::SerializationFormat::MessagePack,
            payload: buffer.clone(),
            checksum: None,
            original_size: buffer.len(),
            compression_ratio: 1.0,
        };
        match serializer.deserialize::<WalEntry>(&serialized_message) {
            Ok(mut entry) => {
                // Recompute the checksum exactly as `append` produced it:
                // over the serialization of the entry with an empty checksum.
                let stored_checksum = std::mem::take(&mut entry.checksum);
                let reserialized = serializer.serialize(&entry)?;
                let computed_checksum =
                    format!("{:x}", Sha256::digest(&reserialized.payload));
                if stored_checksum == computed_checksum {
                    entry.checksum = stored_checksum;
                    max_sequence = Some(match max_sequence {
                        Some(m) => m.max(entry.sequence),
                        None => entry.sequence,
                    });
                    recovered_entries.push(entry);
                } else {
                    tracing::warn!(
                        "WAL entry {} has invalid checksum, stopping recovery",
                        entry.sequence
                    );
                    break;
                }
            }
            Err(_) => {
                tracing::warn!("Failed to deserialize WAL entry, stopping recovery");
                break;
            }
        }
    }
    // Resume sequence numbering after the highest valid recovered entry.
    if let Some(max) = max_sequence {
        self.current_sequence.store(max + 1, Ordering::SeqCst);
    }
    Ok(recovered_entries)
}
/// Requests truncation of the WAL up to the given sequence number.
///
/// NOTE(review): this is currently a stub — it only logs the request and
/// removes nothing from disk, so the size-based truncation triggered by
/// `maintenance()` has no effect. Confirm whether real truncation is
/// implemented elsewhere before relying on it.
async fn truncate(&self, up_to_sequence: u64) -> Result<()> {
    tracing::info!("WAL truncation requested up to sequence {}", up_to_sequence);
    Ok(())
}
/// Returns the current on-disk size of the WAL file in bytes, or 0 when the
/// file does not exist yet.
///
/// # Errors
/// Propagates metadata lookup failures for an existing file.
async fn size(&self) -> Result<u64> {
    let wal_path = self.config.data_dir.join("wal.log");
    if !wal_path.exists() {
        return Ok(0);
    }
    Ok(fs::metadata(&wal_path).await?.len())
}
}
/// Read-only memory-mapped view of a file for zero-copy reads.
#[derive(Debug)]
struct MmapStorage {
    // Kept open so the mapping stays backed by a live file handle.
    _file: File,
    #[allow(dead_code)]
    // The actual read-only mapping.
    mmap: Mmap,
}
impl MmapStorage {
    /// Memory-maps `file_path` read-only.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or mapped.
    fn new(file_path: &Path) -> Result<Self> {
        let file = File::open(file_path)?;
        // SAFETY: the file handle is kept alive in `_file` for the lifetime of
        // the mapping. `Mmap::map` is unsafe because the mapped contents become
        // undefined if the underlying file is modified or truncated while
        // mapped; callers must not mutate the file concurrently.
        let mmap = unsafe { MmapOptions::new().map(&file)? };
        Ok(Self { _file: file, mmap })
    }
    /// Returns `length` bytes starting at `offset`, or an error when the
    /// requested range extends past the end of the mapping.
    ///
    /// Fix: `offset + length` could wrap around in release builds (where
    /// integer overflow does not panic), bypassing the bounds check;
    /// `checked_add` makes the guard overflow-proof.
    #[allow(dead_code)]
    fn read(&self, offset: usize, length: usize) -> Result<&[u8]> {
        let end = offset
            .checked_add(length)
            .ok_or_else(|| anyhow!("Read beyond end of memory-mapped file"))?;
        if end > self.mmap.len() {
            return Err(anyhow!("Read beyond end of memory-mapped file"));
        }
        Ok(&self.mmap[offset..end])
    }
    /// Total length of the mapped region in bytes.
    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.mmap.len()
    }
}
/// A minimal least-recently-used cache backed by a `HashMap` plus an
/// explicit recency list.
///
/// `access_order` keeps keys ordered from least-recently used (front) to
/// most-recently used (back). Operations are O(n) in the entry count, which
/// is acceptable for the cache sizes used by this backend.
#[derive(Debug)]
struct LruCache<K, V> {
    map: HashMap<K, V>,
    max_size: usize,
    access_order: Vec<K>,
}
impl<K: Clone + std::hash::Hash + Eq, V: Clone> LruCache<K, V> {
    /// Creates an empty cache that holds at most `max_size` entries.
    fn new(max_size: usize) -> Self {
        LruCache {
            map: HashMap::new(),
            access_order: Vec::new(),
            max_size,
        }
    }
    /// Moves `key` to the most-recently-used position, appending it if it is
    /// not currently tracked.
    fn touch(&mut self, key: &K) {
        if let Some(pos) = self.access_order.iter().position(|k| k == key) {
            self.access_order.remove(pos);
        }
        self.access_order.push(key.clone());
    }
    /// Returns a clone of the value for `key`, marking it most recently used.
    fn get(&mut self, key: &K) -> Option<V> {
        let value = self.map.get(key)?.clone();
        self.touch(key);
        Some(value)
    }
    /// Inserts or updates `key`. When inserting a brand-new key into a full
    /// cache, the least-recently-used entry is evicted first; updating an
    /// existing key never evicts.
    fn put(&mut self, key: K, value: V) {
        let is_new = !self.map.contains_key(&key);
        if is_new && self.map.len() >= self.max_size {
            // Evict the least-recently-used entry (front of the list), if any.
            if !self.access_order.is_empty() {
                let evicted = self.access_order.remove(0);
                self.map.remove(&evicted);
            }
        }
        self.map.insert(key.clone(), value);
        self.touch(&key);
    }
    /// Number of entries currently held.
    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.map.len()
    }
    /// Drops all entries and recency bookkeeping.
    fn clear(&mut self) {
        self.access_order.clear();
        self.map.clear();
    }
}
/// Storage backend combining a WAL (durability), LMDB (indexed state and
/// snapshot metadata), LRU caches (hot reads), and snapshot files on disk.
pub struct AdvancedStorageBackend {
    // Storage configuration shared with all subcomponents.
    config: StorageConfig,
    // Write-ahead log; every mutation is recorded here first.
    wal: WriteAheadLog,
    // Hot cache for Raft state, keyed by the fixed string "current".
    state_cache: Arc<RwLock<LruCache<String, RaftState>>>,
    // Hot cache for snapshot metadata, keyed by "snapshot-<index>-<term>".
    snapshot_cache: Arc<RwLock<LruCache<String, SnapshotMetadata>>>,
    // Optional memory-mapped file installed via `enable_mmap()`.
    mmap_storage: Arc<RwLock<Option<MmapStorage>>>,
    // Running operational statistics.
    stats: Arc<RwLock<StorageStats>>,
    // Serializer for state/metadata payloads.
    serializer: Arc<Mutex<MessageSerializer>>,
    // LMDB environment and its two named databases.
    environment: Arc<Mutex<Environment>>,
    state_db: Database,
    snapshot_db: Database,
    // Profiler used for `get_profiling_report()`.
    profiler: Arc<Profiler>,
}
impl AdvancedStorageBackend {
/// Creates the backend: opens the LMDB environment and its databases, the
/// WAL, and the in-memory caches, then replays the WAL via `recover()`.
///
/// # Errors
/// Fails on directory creation, LMDB setup, WAL open, or recovery errors.
pub async fn new(config: StorageConfig) -> Result<Self> {
    fs::create_dir_all(&config.data_dir).await?;
    let env_path = config.data_dir.join("lmdb");
    fs::create_dir_all(&env_path).await?;
    let environment = Environment::new()
        .set_max_readers(1024)
        .set_max_dbs(16)
        // The cache budget is reused as the LMDB map size.
        .set_map_size(config.cache_size)
        .open(&env_path)?;
    let state_db = environment.create_db(Some("raft_state"), DatabaseFlags::empty())?;
    let snapshot_db = environment.create_db(Some("snapshots"), DatabaseFlags::empty())?;
    let wal = WriteAheadLog::new(config.clone()).await?;
    // Rough heuristic: ~1 KiB per cached entry; the snapshot-metadata cache
    // gets a tenth of the state cache's capacity.
    let cache_entries = config.cache_size / 1024;
    let state_cache = Arc::new(RwLock::new(LruCache::new(cache_entries)));
    let snapshot_cache = Arc::new(RwLock::new(LruCache::new(cache_entries / 10)));
    let serializer_config = SerializationConfig {
        compression: if config.enable_compression {
            crate::serialization::CompressionAlgorithm::Lz4
        } else {
            crate::serialization::CompressionAlgorithm::None
        },
        ..Default::default()
    };
    let serializer = Arc::new(Mutex::new(MessageSerializer::with_config(
        serializer_config,
    )));
    let mut backend = Self {
        config,
        wal,
        state_cache,
        snapshot_cache,
        mmap_storage: Arc::new(RwLock::new(None)),
        stats: Arc::new(RwLock::new(StorageStats::default())),
        serializer,
        environment: Arc::new(Mutex::new(environment)),
        state_db,
        snapshot_db,
        profiler: Arc::new(Profiler::new()),
    };
    // Replay any WAL entries left over from a previous run.
    backend.recover().await?;
    Ok(backend)
}
/// Replays recovered WAL entries to rebuild backend state on startup.
///
/// Only `WriteRaftState` operations are re-applied to LMDB; the other
/// operation kinds are currently acknowledged with debug logs only.
///
/// # Errors
/// Propagates WAL read and LMDB write failures.
async fn recover(&mut self) -> Result<()> {
    let start_time = Instant::now();
    let entries = self.wal.recover().await?;
    tracing::info!("Recovering from {} WAL entries", entries.len());
    for entry in entries {
        match entry.operation {
            WalOperation::WriteRaftState(state) => {
                // Re-apply the durable state write to LMDB.
                self.store_raft_state_internal(state).await?;
            }
            WalOperation::WriteAppState(_app_state) => {
                tracing::debug!("Recovered application state");
            }
            WalOperation::CreateSnapshot(_metadata) => {
                tracing::debug!("Recovered snapshot metadata");
            }
            WalOperation::TruncateLog(index) => {
                tracing::debug!("Recovered log truncation to index {}", index);
            }
            WalOperation::Commit(sequence) => {
                tracing::debug!("Recovered commit for sequence {}", sequence);
            }
        }
    }
    let mut stats = self.stats.write().await;
    stats.recovery_operations += 1;
    // NOTE(review): this overwrites the running average read latency with the
    // recovery duration rather than folding it in — confirm intended.
    stats.avg_read_latency = start_time.elapsed();
    tracing::info!("Recovery completed in {:?}", start_time.elapsed());
    Ok(())
}
/// Durably persists the Raft state: WAL first, then LMDB, then the
/// in-memory cache, updating write statistics along the way.
///
/// # Errors
/// Propagates WAL append and LMDB write failures.
pub async fn store_raft_state(&self, state: RaftState) -> Result<()> {
    let start_time = Instant::now();
    // Write-ahead logging: record the intent before mutating LMDB.
    self.wal
        .append(WalOperation::WriteRaftState(state.clone()))
        .await?;
    self.store_raft_state_internal(state.clone()).await?;
    // Keep the hot-path cache coherent with what was just persisted.
    self.state_cache
        .write()
        .await
        .put("current".to_string(), state);
    // Fold this write's latency into the running average.
    let elapsed = start_time.elapsed();
    let mut stats = self.stats.write().await;
    stats.total_operations += 1;
    let n = stats.total_operations as u32;
    stats.avg_write_latency = (stats.avg_write_latency * (n - 1) + elapsed) / n;
    Ok(())
}
/// Serializes `state` and writes it under the fixed key `"current"` in the
/// LMDB `raft_state` database, inside one write transaction.
///
/// NOTE(review): the payload is written as produced by the configured
/// serializer (possibly LZ4-compressed), while `load_raft_state` rebuilds
/// the message declaring `CompressionAlgorithm::None` — confirm the
/// serializer handles this asymmetry.
///
/// # Errors
/// Propagates serialization and LMDB transaction failures.
async fn store_raft_state_internal(&self, state: RaftState) -> Result<()> {
    let serialized = {
        let mut serializer = self.serializer.lock().await;
        serializer.serialize(&state)?
    };
    {
        let env = self.environment.lock().await;
        let mut txn = env.begin_rw_txn()?;
        txn.put(
            self.state_db,
            &b"current",
            &serialized.payload,
            WriteFlags::empty(),
        )?;
        txn.commit()?;
    }
    // Account for the bytes physically written to LMDB.
    let mut stats = self.stats.write().await;
    stats.bytes_written += serialized.payload.len() as u64;
    Ok(())
}
/// Loads the current Raft state, serving from the LRU cache when possible
/// and falling back to LMDB otherwise.
///
/// Bug fix: `bytes_read` previously accumulated `size_of::<RaftState>()`
/// (the in-memory struct size) rather than the number of serialized bytes
/// actually read from LMDB.
///
/// # Errors
/// Propagates LMDB read and deserialization failures.
pub async fn load_raft_state(&self) -> Result<Option<RaftState>> {
    let start_time = Instant::now();
    // Fast path: cache hit.
    {
        let mut cache = self.state_cache.write().await;
        if let Some(state) = cache.get(&"current".to_string()) {
            let mut stats = self.stats.write().await;
            // Exponentially-weighted hit ratio: decay, then credit the hit.
            stats.cache_hit_ratio = (stats.cache_hit_ratio * 0.9) + 0.1;
            return Ok(Some(state));
        }
    }
    // Slow path: read the serialized state from LMDB (read-only txn).
    let data_result = {
        let env = self.environment.lock().await;
        let txn = env.begin_ro_txn()?;
        match txn.get(self.state_db, &b"current") {
            Ok(data) => Ok(Some(data.to_vec())),
            Err(lmdb::Error::NotFound) => Ok(None),
            Err(e) => Err(e),
        }
    };
    let mut bytes_read = 0u64;
    let result = match data_result? {
        Some(data) => {
            let data_len = data.len();
            bytes_read = data_len as u64;
            // Wrap the raw payload so the serializer can decode it.
            // NOTE(review): declares CompressionAlgorithm::None even though
            // writes may be LZ4-compressed — confirm the serializer copes.
            let serialized_message = crate::serialization::SerializedMessage {
                schema_version: Default::default(),
                compression: crate::serialization::CompressionAlgorithm::None,
                format: crate::serialization::SerializationFormat::MessagePack,
                payload: data,
                checksum: None,
                original_size: data_len,
                compression_ratio: 1.0,
            };
            let state = {
                let mut serializer = self.serializer.lock().await;
                serializer.deserialize::<RaftState>(&serialized_message)?
            };
            // Populate the cache for subsequent reads.
            {
                let mut cache = self.state_cache.write().await;
                cache.put("current".to_string(), state.clone());
            }
            Some(state)
        }
        None => None,
    };
    let mut stats = self.stats.write().await;
    stats.total_operations += 1;
    stats.bytes_read += bytes_read;
    stats.avg_read_latency = (stats.avg_read_latency * (stats.total_operations - 1) as u32
        + start_time.elapsed())
        / stats.total_operations as u32;
    // Cache miss: decay the hit ratio without crediting a hit.
    stats.cache_hit_ratio = (stats.cache_hit_ratio * 0.9) + 0.0;
    Ok(result)
}
/// Compresses `data` with LZ4, using rayon to compress 256 KiB chunks in
/// parallel for inputs of 1 MiB or more.
///
/// Chunked output layout (all little-endian):
/// `[u64 original_size][u32 num_chunks][u32 chunk_size; num_chunks][chunks...]`
///
/// NOTE(review): small inputs use plain `oxiarc_lz4::compress` with no
/// format marker, so `parallel_decompress` can only distinguish the two
/// layouts heuristically — confirm the decode side's fallbacks cover all
/// inputs. On a per-chunk compression error the raw chunk is silently
/// passed through, which the decode side must tolerate — confirm.
fn parallel_compress(&self, data: &[u8]) -> Vec<u8> {
    // Inputs below 1 MiB are compressed in a single shot.
    const PARALLEL_THRESHOLD: usize = 1024 * 1024;
    const CHUNK_SIZE: usize = 256 * 1024;
    if data.len() < PARALLEL_THRESHOLD {
        return oxiarc_lz4::compress(data).unwrap_or_else(|_| data.to_vec());
    }
    use rayon::prelude::*;
    let chunks: Vec<&[u8]> = data.chunks(CHUNK_SIZE).collect();
    let num_chunks = chunks.len();
    // Compress chunks independently in parallel.
    let compressed_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| oxiarc_lz4::compress_block(chunk).unwrap_or_else(|_| chunk.to_vec()))
        .collect();
    let chunk_sizes: Vec<u32> = compressed_chunks.iter().map(|c| c.len() as u32).collect();
    let total_compressed_size: usize = compressed_chunks.iter().map(|c| c.len()).sum();
    // Header: 8-byte original size + 4-byte chunk count + 4 bytes per chunk.
    let metadata_size = 8 + 4 + (num_chunks * 4);
    let mut result = Vec::with_capacity(metadata_size + total_compressed_size);
    result.extend_from_slice(&(data.len() as u64).to_le_bytes());
    result.extend_from_slice(&(num_chunks as u32).to_le_bytes());
    for size in &chunk_sizes {
        result.extend_from_slice(&size.to_le_bytes());
    }
    for chunk in compressed_chunks {
        result.extend_from_slice(&chunk);
    }
    result
}
/// Decompresses data produced by `parallel_compress`.
///
/// Tries the chunked layout first (`[u64 original_size][u32 num_chunks]
/// [u32 sizes...][chunks...]`); anything failing the header sanity checks
/// is handed to the plain LZ4 decompressor with a 100 MiB output cap.
///
/// NOTE(review): the layouts carry no magic marker, so header detection is
/// heuristic — a plain LZ4 stream of >= 12 bytes could in principle be
/// misread as chunked. Confirm this cannot occur for real payloads.
///
/// # Errors
/// Returns an error when chunk bounds are inconsistent or both
/// interpretations fail to decompress.
fn parallel_decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
    // Too short to contain a chunked header: must be a plain stream.
    if data.len() < 12 {
        return oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
            .map_err(|e| anyhow!("Decompression failed: {}", e));
    }
    let original_size = u64::from_le_bytes(
        data[0..8]
            .try_into()
            .expect("slice should be exactly 8 bytes"),
    ) as usize;
    let num_chunks = u32::from_le_bytes(
        data[8..12]
            .try_into()
            .expect("slice should be exactly 4 bytes"),
    ) as usize;
    // Implausible chunk counts mean this is not the chunked format.
    if num_chunks == 0 || num_chunks > 100000 {
        return oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
            .map_err(|e| anyhow!("Decompression failed: {}", e));
    }
    let metadata_size = 12 + (num_chunks * 4);
    if data.len() < metadata_size {
        return oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
            .map_err(|e| anyhow!("Decompression failed: {}", e));
    }
    // Read the per-chunk compressed sizes from the header.
    let mut chunk_sizes = Vec::with_capacity(num_chunks);
    for i in 0..num_chunks {
        let offset = 12 + (i * 4);
        let size = u32::from_le_bytes(
            data[offset..offset + 4]
                .try_into()
                .expect("slice should be exactly 4 bytes"),
        ) as usize;
        chunk_sizes.push(size);
    }
    // Slice out each chunk's compressed bytes, bounds-checked.
    let mut chunks = Vec::with_capacity(num_chunks);
    let mut offset = metadata_size;
    for &size in &chunk_sizes {
        if offset + size > data.len() {
            return Err(anyhow!("Invalid chunk data"));
        }
        chunks.push(&data[offset..offset + size]);
        offset += size;
    }
    use rayon::prelude::*;
    // NOTE(review): the *total* original size is passed as every chunk's
    // output bound — confirm oxiarc_lz4 treats it as a max capacity, since
    // each chunk decompresses to at most 256 KiB.
    let decompressed_chunks: Result<Vec<Vec<u8>>, _> = chunks
        .par_iter()
        .map(|chunk| {
            oxiarc_lz4::decompress_block(chunk, original_size)
                .map_err(|e| anyhow!("Chunk decompression failed: {}", e))
        })
        .collect();
    let decompressed_chunks = decompressed_chunks?;
    let mut result = Vec::with_capacity(original_size);
    for chunk in decompressed_chunks {
        result.extend_from_slice(&chunk);
    }
    Ok(result)
}
/// Returns a human-readable report from the embedded profiler.
pub fn get_profiling_report(&self) -> String {
    self.profiler.get_report()
}
/// Writes a snapshot atomically to disk and records its metadata in LMDB,
/// the snapshot cache, and the WAL.
///
/// The data is compressed via `parallel_compress` when compression is
/// enabled; the checksum stored in the metadata covers the bytes as written
/// to disk (i.e. the compressed form when compression is on), which is what
/// `load_snapshot` verifies against.
///
/// # Errors
/// Propagates file, serialization, LMDB, and WAL failures.
pub async fn create_snapshot(
    &self,
    last_included_index: u64,
    last_included_term: u64,
    configuration: Vec<OxirsNodeId>,
    data: &[u8],
) -> Result<SnapshotMetadata> {
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("SystemTime should be after UNIX_EPOCH")
        .as_secs();
    // Optionally compress the payload before it touches disk.
    let (final_data, compressed) = if self.config.enable_compression {
        let compressed = self.parallel_compress(data);
        (compressed, true)
    } else {
        (data.to_vec(), false)
    };
    // Checksum of the on-disk bytes, used later for corruption detection.
    let checksum = format!("{:x}", Sha256::digest(&final_data));
    let metadata = SnapshotMetadata {
        last_included_index,
        last_included_term,
        configuration,
        timestamp,
        size: final_data.len() as u64,
        checksum,
    };
    let snapshot_path = self.config.data_dir.join(format!(
        "snapshot-{last_included_index}-{last_included_term}.dat"
    ));
    // Write via temp-file-plus-rename so readers never see a partial file.
    let mut atomic_file = AtomicFile::create(snapshot_path, self.config.wal_sync_mode)?;
    atomic_file.write_all(&final_data)?;
    atomic_file.commit()?;
    let key = format!("snapshot-{last_included_index}-{last_included_term}");
    // Persist the metadata in LMDB under the snapshot key.
    {
        let serialized_metadata = {
            let mut serializer = self.serializer.lock().await;
            serializer.serialize(&metadata)?
        };
        let env = self.environment.lock().await;
        let mut txn = env.begin_rw_txn()?;
        txn.put(
            self.snapshot_db,
            &key.as_bytes(),
            &serialized_metadata.payload,
            WriteFlags::empty(),
        )?;
        txn.commit()?;
    }
    // Warm the metadata cache.
    {
        let mut cache = self.snapshot_cache.write().await;
        cache.put(key.clone(), metadata.clone());
    }
    // Record the snapshot in the WAL for replay on restart.
    self.wal
        .append(WalOperation::CreateSnapshot(metadata.clone()))
        .await?;
    let mut stats = self.stats.write().await;
    stats.checkpoints_created += 1;
    stats.bytes_written += final_data.len() as u64;
    tracing::info!(
        "Created snapshot {} (compressed: {}, size: {} bytes)",
        metadata.last_included_index,
        compressed,
        final_data.len()
    );
    Ok(metadata)
}
/// Loads snapshot bytes for the given index/term, verifying the on-disk
/// checksum against the stored metadata and decompressing chunked data.
///
/// Returns `Ok(None)` when no matching metadata or snapshot file exists.
///
/// NOTE(review): when `parallel_decompress` fails, the raw file bytes are
/// returned as-is. That is correct for uncompressed snapshots, but it also
/// silently masks genuine decompression failures of compressed ones —
/// confirm this fallback is intended.
///
/// # Errors
/// Returns an error when the file checksum does not match the metadata
/// (corruption), or on LMDB/IO failures.
pub async fn load_snapshot(
    &self,
    last_included_index: u64,
    last_included_term: u64,
) -> Result<Option<Vec<u8>>> {
    let key = format!("snapshot-{last_included_index}-{last_included_term}");
    // Metadata lookup: cache first, then LMDB.
    let metadata = {
        let mut cache = self.snapshot_cache.write().await;
        cache.get(&key)
    };
    let metadata = if let Some(meta) = metadata {
        meta
    } else {
        let data_result = {
            let env = self.environment.lock().await;
            let txn = env.begin_ro_txn()?;
            match txn.get(self.snapshot_db, &key.as_bytes()) {
                Ok(data) => Ok(Some(data.to_vec())),
                Err(lmdb::Error::NotFound) => Ok(None),
                Err(e) => Err(e),
            }
        };
        match data_result? {
            Some(data) => {
                let data_len = data.len();
                // Wrap the raw payload so the serializer can decode it.
                let serialized_message = crate::serialization::SerializedMessage {
                    schema_version: Default::default(),
                    compression: crate::serialization::CompressionAlgorithm::None,
                    format: crate::serialization::SerializationFormat::MessagePack,
                    payload: data,
                    checksum: None,
                    original_size: data_len,
                    compression_ratio: 1.0,
                };
                let metadata = {
                    let mut serializer = self.serializer.lock().await;
                    serializer.deserialize::<SnapshotMetadata>(&serialized_message)?
                };
                // Warm the cache for next time.
                {
                    let mut cache = self.snapshot_cache.write().await;
                    cache.put(key.clone(), metadata.clone());
                }
                metadata
            }
            None => return Ok(None),
        }
    };
    let snapshot_path = self.config.data_dir.join(format!(
        "snapshot-{last_included_index}-{last_included_term}.dat"
    ));
    if !snapshot_path.exists() {
        return Ok(None);
    }
    let data = fs::read(&snapshot_path).await?;
    // Verify the on-disk bytes against the checksum recorded at creation.
    let computed_checksum = format!("{:x}", Sha256::digest(&data));
    if computed_checksum != metadata.checksum {
        let mut stats = self.stats.write().await;
        stats.corruption_detections += 1;
        return Err(anyhow!("Snapshot checksum verification failed"));
    }
    // Attempt decompression; on failure fall back to the raw bytes (see
    // NOTE(review) above).
    let final_data = match self.parallel_decompress(&data) {
        Ok(decompressed) => decompressed,
        Err(_) => data,
    };
    let mut stats = self.stats.write().await;
    stats.bytes_read += final_data.len() as u64;
    Ok(Some(final_data))
}
/// Returns a point-in-time copy of the storage statistics, refreshing the
/// reported WAL size from disk (0 when the size cannot be read).
pub async fn stats(&self) -> StorageStats {
    let mut snapshot = self.stats.read().await.clone();
    snapshot.current_wal_size = self.wal.size().await.unwrap_or(0);
    snapshot
}
/// Drops every entry from both the state and snapshot LRU caches.
pub async fn clear_caches(&self) {
    self.state_cache.write().await.clear();
    self.snapshot_cache.write().await.clear();
    tracing::info!("Cleared all storage caches");
}
/// Runs storage compaction when enabled in the configuration.
///
/// NOTE(review): this is currently a stub — no data is actually compacted;
/// it only logs and bumps `compactions_performed`, so the "completed" log
/// message is misleading. Confirm where real compaction lives before
/// relying on this.
pub async fn compact(&self) -> Result<()> {
    if !self.config.enable_compaction {
        return Ok(());
    }
    let start_time = Instant::now();
    tracing::info!("Storage compaction completed in {:?}", start_time.elapsed());
    let mut stats = self.stats.write().await;
    stats.compactions_performed += 1;
    Ok(())
}
/// Memory-maps `file_path` for read access when mmap is enabled in the
/// configuration, replacing any previously installed mapping. No-op when
/// mmap is disabled.
///
/// # Errors
/// Fails if the file cannot be opened or mapped.
pub async fn enable_mmap(&self, file_path: &Path) -> Result<()> {
    if !self.config.enable_mmap {
        return Ok(());
    }
    let mmap_storage = MmapStorage::new(file_path)?;
    let mut mmap = self.mmap_storage.write().await;
    *mmap = Some(mmap_storage);
    tracing::info!("Enabled memory-mapped storage for {:?}", file_path);
    Ok(())
}
/// Periodic housekeeping: requests WAL truncation when the WAL exceeds its
/// configured maximum size, then runs compaction if enabled.
///
/// NOTE(review): `WriteAheadLog::truncate` is a logging stub in the code
/// visible here, so the size check currently has no durable effect.
pub async fn maintenance(&self) -> Result<()> {
    let wal_size = self.wal.size().await?;
    if wal_size > self.config.max_wal_size {
        tracing::info!(
            "WAL size {} exceeds limit {}, truncating",
            wal_size,
            self.config.max_wal_size
        );
        // Passes 0 as the cutoff sequence; the stubbed truncate ignores it.
        self.wal.truncate(0).await?;
    }
    if self.config.enable_compaction {
        self.compact().await?;
    }
    Ok(())
}
}
/// Abstract persistence interface for Raft state and snapshots.
#[async_trait]
pub trait StorageBackend: Send + Sync {
    /// Durably stores the current Raft state.
    async fn store_raft_state(&self, state: RaftState) -> Result<()>;
    /// Loads the most recently stored Raft state, if any.
    async fn load_raft_state(&self) -> Result<Option<RaftState>>;
    /// Creates a snapshot covering the log up to the given index/term.
    async fn create_snapshot(
        &self,
        last_included_index: u64,
        last_included_term: u64,
        configuration: Vec<OxirsNodeId>,
        data: &[u8],
    ) -> Result<SnapshotMetadata>;
    /// Loads snapshot data for the given index/term, if present and intact.
    async fn load_snapshot(
        &self,
        last_included_index: u64,
        last_included_term: u64,
    ) -> Result<Option<Vec<u8>>>;
    /// Returns current storage statistics.
    async fn stats(&self) -> StorageStats;
}
#[async_trait]
impl StorageBackend for AdvancedStorageBackend {
    // Each trait method delegates to the inherent method of the same name.
    // Rust method resolution prefers inherent impls inside these bodies, so
    // this is plain delegation, not infinite recursion.
    async fn store_raft_state(&self, state: RaftState) -> Result<()> {
        self.store_raft_state(state).await
    }
    async fn load_raft_state(&self) -> Result<Option<RaftState>> {
        self.load_raft_state().await
    }
    async fn create_snapshot(
        &self,
        last_included_index: u64,
        last_included_term: u64,
        configuration: Vec<OxirsNodeId>,
        data: &[u8],
    ) -> Result<SnapshotMetadata> {
        self.create_snapshot(last_included_index, last_included_term, configuration, data)
            .await
    }
    async fn load_snapshot(
        &self,
        last_included_index: u64,
        last_included_term: u64,
    ) -> Result<Option<Vec<u8>>> {
        self.load_snapshot(last_included_index, last_included_term)
            .await
    }
    async fn stats(&self) -> StorageStats {
        self.stats().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    /// Builds a backend rooted in a fresh temporary directory so each test
    /// is isolated from the others and from real data.
    async fn create_test_storage() -> (AdvancedStorageBackend, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let config = StorageConfig {
            data_dir: temp_dir.path().to_path_buf(),
            ..Default::default()
        };
        let storage = AdvancedStorageBackend::new(config).await.unwrap();
        (storage, temp_dir)
    }
    /// Round-trips a RaftState through store/load.
    #[tokio::test]
    async fn test_raft_state_storage() {
        let (storage, _temp_dir) = create_test_storage().await;
        let state = RaftState {
            current_term: 1,
            voted_for: Some(1),
            log: vec![],
            commit_index: 0,
            last_applied: 0,
        };
        storage.store_raft_state(state.clone()).await.unwrap();
        let loaded_state = storage.load_raft_state().await.unwrap().unwrap();
        assert_eq!(state.current_term, loaded_state.current_term);
        assert_eq!(state.voted_for, loaded_state.voted_for);
    }
    /// Snapshot create + load must return the original bytes and metadata.
    #[tokio::test]
    async fn test_snapshot_creation_and_loading() {
        let (storage, _temp_dir) = create_test_storage().await;
        let data = b"test snapshot data";
        let metadata = storage
            .create_snapshot(10, 1, vec![1, 2, 3], data)
            .await
            .unwrap();
        assert_eq!(metadata.last_included_index, 10);
        assert_eq!(metadata.last_included_term, 1);
        assert_eq!(metadata.configuration, vec![1, 2, 3]);
        let loaded_data = storage.load_snapshot(10, 1).await.unwrap().unwrap();
        assert_eq!(&loaded_data, data);
    }
    /// Repeated loads should be served from the cache, pushing the
    /// exponentially-weighted hit ratio above zero.
    #[tokio::test]
    async fn test_cache_functionality() {
        let (storage, _temp_dir) = create_test_storage().await;
        let state = RaftState {
            current_term: 1,
            voted_for: Some(1),
            log: vec![],
            commit_index: 0,
            last_applied: 0,
        };
        storage.store_raft_state(state.clone()).await.unwrap();
        for _ in 0..5 {
            let _loaded_state = storage.load_raft_state().await.unwrap().unwrap();
        }
        let stats = storage.stats().await;
        assert!(stats.cache_hit_ratio > 0.0);
    }
    /// Flipping a byte in the snapshot file must be caught by checksum
    /// verification in `load_snapshot`.
    #[tokio::test]
    async fn test_corruption_detection() {
        let (storage, temp_dir) = create_test_storage().await;
        let data = b"test data";
        let _metadata = storage.create_snapshot(1, 1, vec![1], data).await.unwrap();
        let snapshot_path = temp_dir.path().join("snapshot-1-1.dat");
        let mut corrupted_data = fs::read(&snapshot_path).await.unwrap();
        // Corrupt the first byte, then write the file back in place.
        corrupted_data[0] ^= 0xFF;
        fs::write(&snapshot_path, corrupted_data).await.unwrap();
        let result = storage.load_snapshot(1, 1).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("checksum verification failed"));
    }
    /// State persisted by one backend instance must be visible after a
    /// "restart" (a new instance over the same data directory).
    #[tokio::test]
    async fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config = StorageConfig {
            data_dir: temp_dir.path().to_path_buf(),
            ..Default::default()
        };
        let state = RaftState {
            current_term: 5,
            voted_for: Some(2),
            log: vec![],
            commit_index: 0,
            last_applied: 0,
        };
        {
            let storage = AdvancedStorageBackend::new(config.clone()).await.unwrap();
            storage.store_raft_state(state.clone()).await.unwrap();
        }
        let storage = AdvancedStorageBackend::new(config).await.unwrap();
        let recovered_state = storage.load_raft_state().await.unwrap().unwrap();
        assert_eq!(state.current_term, recovered_state.current_term);
        assert_eq!(state.voted_for, recovered_state.voted_for);
    }
}