use super::config::StorageConfig;
#[cfg(test)]
use super::error::StorageError;
use super::recovery::*;
use super::stats::StorageStats;
use super::types::*;
use crate::network::LogEntry;
use crate::raft::{OxirsNodeId, RdfApp, RdfCommand};
use crate::shard::ShardId;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use oxirs_core::model::Triple;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fs;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
#[async_trait]
pub trait StorageBackend: Send + Sync {
async fn create_shard(&self, shard_id: ShardId) -> Result<()>;
async fn delete_shard(&self, shard_id: ShardId) -> Result<()>;
async fn insert_triple_to_shard(&self, shard_id: ShardId, triple: Triple) -> Result<()>;
async fn delete_triple_from_shard(&self, shard_id: ShardId, triple: &Triple) -> Result<()>;
async fn query_shard(
&self,
shard_id: ShardId,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
) -> Result<Vec<Triple>>;
async fn get_shard_size(&self, shard_id: ShardId) -> Result<u64>;
async fn get_shard_triple_count(&self, shard_id: ShardId) -> Result<usize>;
async fn export_shard(&self, shard_id: ShardId) -> Result<Vec<Triple>>;
async fn import_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()>;
async fn get_shard_triples(&self, shard_id: ShardId) -> Result<Vec<Triple>>;
async fn insert_triples_to_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()>;
async fn mark_shard_for_deletion(&self, shard_id: ShardId) -> Result<()>;
}
pub mod mock {
use super::*;
use oxirs_core::RdfTerm;
use std::collections::HashMap;
#[derive(Debug, Default)]
pub struct MockStorageBackend {
shards: Arc<RwLock<HashMap<ShardId, Vec<Triple>>>>,
}
impl MockStorageBackend {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl StorageBackend for MockStorageBackend {
async fn create_shard(&self, shard_id: ShardId) -> Result<()> {
self.shards.write().await.insert(shard_id, Vec::new());
Ok(())
}
async fn delete_shard(&self, shard_id: ShardId) -> Result<()> {
self.shards.write().await.remove(&shard_id);
Ok(())
}
async fn insert_triple_to_shard(&self, shard_id: ShardId, triple: Triple) -> Result<()> {
let mut shards = self.shards.write().await;
if let Some(shard) = shards.get_mut(&shard_id) {
shard.push(triple);
} else {
shards.insert(shard_id, vec![triple]);
}
Ok(())
}
async fn delete_triple_from_shard(&self, shard_id: ShardId, triple: &Triple) -> Result<()> {
let mut shards = self.shards.write().await;
if let Some(shard) = shards.get_mut(&shard_id) {
shard.retain(|t| t != triple);
}
Ok(())
}
async fn query_shard(
&self,
shard_id: ShardId,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
) -> Result<Vec<Triple>> {
let shards = self.shards.read().await;
if let Some(shard) = shards.get(&shard_id) {
let results: Vec<Triple> = shard
.iter()
.filter(|triple| {
let subject_match = subject.map_or(true, |s| {
if let oxirs_core::model::Subject::NamedNode(named_node) =
triple.subject()
{
named_node.as_str() == s
} else {
triple.subject().to_string() == s
}
});
let predicate_match =
predicate.map_or(true, |p| triple.predicate().as_str() == p);
let object_match = object.map_or(true, |o| {
if let oxirs_core::Object::NamedNode(named_node) = triple.object() {
named_node.as_str() == o
} else {
triple.object().to_string() == o
}
});
subject_match && predicate_match && object_match
})
.cloned()
.collect();
Ok(results)
} else {
Ok(Vec::new())
}
}
async fn get_shard_size(&self, shard_id: ShardId) -> Result<u64> {
let shards = self.shards.read().await;
if let Some(shard) = shards.get(&shard_id) {
Ok((shard.len() * 100) as u64)
} else {
Ok(0)
}
}
async fn get_shard_triple_count(&self, shard_id: ShardId) -> Result<usize> {
let shards = self.shards.read().await;
Ok(shards.get(&shard_id).map_or(0, |s| s.len()))
}
async fn export_shard(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
let shards = self.shards.read().await;
Ok(shards.get(&shard_id).cloned().unwrap_or_default())
}
async fn import_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()> {
self.shards.write().await.insert(shard_id, triples);
Ok(())
}
async fn get_shard_triples(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
let shards = self.shards.read().await;
Ok(shards.get(&shard_id).cloned().unwrap_or_default())
}
async fn insert_triples_to_shard(
&self,
shard_id: ShardId,
triples: Vec<Triple>,
) -> Result<()> {
let mut shards = self.shards.write().await;
if let Some(shard) = shards.get_mut(&shard_id) {
shard.extend(triples);
} else {
shards.insert(shard_id, triples);
}
Ok(())
}
async fn mark_shard_for_deletion(&self, shard_id: ShardId) -> Result<()> {
self.shards.write().await.remove(&shard_id);
Ok(())
}
}
}
pub struct PersistentStorage {
data_dir: PathBuf,
node_id: OxirsNodeId,
raft_state: Arc<RwLock<RaftState>>,
app_state: Arc<RwLock<RdfApp>>,
config: StorageConfig,
wal_sequence: Arc<RwLock<u64>>,
wal_writer: Arc<RwLock<Option<BufWriter<File>>>>,
}
impl PersistentStorage {
pub async fn new(node_id: OxirsNodeId, config: StorageConfig) -> Result<Self> {
let data_dir = PathBuf::from(&config.data_dir).join(format!("node-{node_id}"));
if !data_dir.exists() {
std::fs::create_dir_all(&data_dir)?;
}
let storage = Self {
data_dir,
node_id,
raft_state: Arc::new(RwLock::new(RaftState::default())),
app_state: Arc::new(RwLock::new(RdfApp::default())),
config,
wal_sequence: Arc::new(RwLock::new(0)),
wal_writer: Arc::new(RwLock::new(None)),
};
if storage.config.enable_wal {
storage.init_wal().await?;
}
if storage.config.enable_crash_recovery {
storage.recover_from_crash().await?;
}
storage.load_state().await?;
Ok(storage)
}
async fn init_wal(&self) -> Result<()> {
let wal_path = self.data_dir.join("wal.log");
let mut sequence = 0;
if wal_path.exists() {
sequence = self.get_last_wal_sequence(&wal_path).await?;
}
let wal_file = OpenOptions::new()
.create(true)
.append(true)
.open(&wal_path)?;
let writer = BufWriter::new(wal_file);
*self.wal_writer.write().await = Some(writer);
*self.wal_sequence.write().await = sequence;
tracing::info!(
"Initialized WAL for node {} at sequence {}",
self.node_id,
sequence
);
Ok(())
}
async fn get_last_wal_sequence(&self, wal_path: &Path) -> Result<u64> {
let file = File::open(wal_path)?;
let mut reader = BufReader::new(file);
let mut last_sequence = 0;
loop {
let mut length_bytes = [0u8; 8];
match reader.read_exact(&mut length_bytes) {
Ok(_) => {
let length = u64::from_le_bytes(length_bytes);
if length > 100 * 1024 * 1024 {
tracing::warn!("WAL entry length too large: {}, skipping", length);
break;
}
let mut entry_bytes = vec![0u8; length as usize];
match reader.read_exact(&mut entry_bytes) {
Ok(_) => {
if let Ok(entry) = oxicode::serde::decode_from_slice(
&entry_bytes,
oxicode::config::standard(),
)
.map(|(v, _): (WalEntry, _)| v)
{
last_sequence = entry.sequence;
}
}
Err(_) => break, }
}
Err(_) => break, }
}
Ok(last_sequence)
}
#[allow(dead_code)]
async fn recover_from_wal_internal(&self) -> Result<()> {
let wal_path = self.data_dir.join("wal.log");
if !wal_path.exists() {
return Ok(());
}
tracing::info!("Starting crash recovery for node {}", self.node_id);
let file = File::open(&wal_path)?;
let mut reader = BufReader::new(file);
let mut uncommitted_ops: Vec<WalEntry> = Vec::new();
let mut last_commit_sequence = 0;
loop {
let mut length_bytes = [0u8; 8];
match reader.read_exact(&mut length_bytes) {
Ok(_) => {
let length = u64::from_le_bytes(length_bytes);
let mut entry_bytes = vec![0u8; length as usize];
reader.read_exact(&mut entry_bytes)?;
if let Ok(entry) =
oxicode::serde::decode_from_slice(&entry_bytes, oxicode::config::standard())
.map(|(v, _): (WalEntry, _)| v)
{
if self.verify_wal_entry_checksum(&entry)? {
match &entry.operation {
WalOperation::Commit(seq) => {
last_commit_sequence = *seq;
for op_entry in &uncommitted_ops {
if op_entry.sequence <= last_commit_sequence {
self.apply_wal_operation(&op_entry.operation).await?;
}
}
uncommitted_ops.retain(|op| op.sequence > last_commit_sequence);
}
_ => {
uncommitted_ops.push(entry);
}
}
} else {
tracing::warn!(
"Corrupted WAL entry detected at sequence {}",
entry.sequence
);
}
}
}
Err(_) => break, }
}
tracing::info!(
"Crash recovery completed for node {}. Last commit: {}, {} uncommitted operations",
self.node_id,
last_commit_sequence,
uncommitted_ops.len()
);
Ok(())
}
fn verify_wal_entry_checksum(&self, entry: &WalEntry) -> Result<bool> {
let op_bytes =
oxicode::serde::encode_to_vec(&entry.operation, oxicode::config::standard())?;
let mut hasher = Sha256::new();
hasher.update(&op_bytes);
hasher.update(entry.sequence.to_le_bytes());
hasher.update(entry.timestamp.to_le_bytes());
let computed_checksum = format!("{:x}", hasher.finalize());
Ok(computed_checksum == entry.checksum)
}
#[allow(dead_code)]
async fn apply_wal_operation(&self, operation: &WalOperation) -> Result<()> {
match operation {
WalOperation::WriteRaftState(state) => {
*self.raft_state.write().await = state.clone();
}
WalOperation::WriteAppState(app_state) => {
*self.app_state.write().await = app_state.clone();
}
WalOperation::TruncateLog(from_index) => {
let mut state = self.raft_state.write().await;
state.log.retain(|entry| entry.index < *from_index);
}
WalOperation::CreateSnapshot(_metadata) => {
}
WalOperation::Commit(_) => {
}
}
Ok(())
}
pub async fn get_current_term(&self) -> u64 {
self.raft_state.read().await.current_term
}
pub async fn set_current_term(&self, term: u64) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.current_term = term;
state.voted_for = None; }
self.save_state().await
}
pub async fn get_voted_for(&self) -> Option<OxirsNodeId> {
self.raft_state.read().await.voted_for
}
pub async fn set_voted_for(&self, candidate_id: Option<OxirsNodeId>) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.voted_for = candidate_id;
}
self.save_state().await
}
pub async fn append_entries(&self, entries: Vec<LogEntry>) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.log.extend(entries);
}
self.save_state().await
}
pub async fn get_log_entries(&self, start: u64, end: u64) -> Vec<LogEntry> {
let state = self.raft_state.read().await;
state
.log
.iter()
.filter(|entry| entry.index >= start && entry.index < end)
.cloned()
.collect()
}
pub async fn get_log_entry(&self, index: u64) -> Option<LogEntry> {
let state = self.raft_state.read().await;
state.log.iter().find(|entry| entry.index == index).cloned()
}
pub async fn get_last_log_index(&self) -> u64 {
let state = self.raft_state.read().await;
state.log.last().map(|entry| entry.index).unwrap_or(0)
}
pub async fn get_last_log_term(&self) -> u64 {
let state = self.raft_state.read().await;
state.log.last().map(|entry| entry.term).unwrap_or(0)
}
pub async fn truncate_log(&self, from_index: u64) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.log.retain(|entry| entry.index < from_index);
}
self.save_state().await
}
pub async fn get_commit_index(&self) -> u64 {
self.raft_state.read().await.commit_index
}
pub async fn set_commit_index(&self, index: u64) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.commit_index = index;
}
self.save_state().await
}
pub async fn get_last_applied(&self) -> u64 {
self.raft_state.read().await.last_applied
}
pub async fn set_last_applied(&self, index: u64) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.last_applied = index;
}
self.save_state().await
}
pub async fn apply_command(&self, command: &RdfCommand) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.apply_command(command);
self.save_app_state().await
}
pub async fn get_app_state(&self) -> RdfApp {
self.app_state.read().await.clone()
}
pub async fn create_snapshot(&self) -> Result<SnapshotMetadata> {
let raft_state = self.raft_state.read().await;
let app_state = self.app_state.read().await;
let last_log_entry = raft_state.log.last();
let snapshot_path = self.data_dir.join("snapshot.json");
let snapshot_data = serde_json::to_vec(&*app_state)?;
fs::write(&snapshot_path, &snapshot_data)?;
let mut hasher = Sha256::new();
hasher.update(&snapshot_data);
let checksum = format!("{:x}", hasher.finalize());
let metadata = SnapshotMetadata {
last_included_index: last_log_entry.map(|e| e.index).unwrap_or(0),
last_included_term: last_log_entry.map(|e| e.term).unwrap_or(0),
configuration: vec![self.node_id], timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
size: snapshot_data.len() as u64,
checksum,
};
let metadata_path = self.data_dir.join("snapshot_metadata.json");
let metadata_data = serde_json::to_vec(&metadata)?;
fs::write(&metadata_path, &metadata_data)?;
tracing::info!(
"Created snapshot for node {} at index {} with {} bytes",
self.node_id,
metadata.last_included_index,
metadata.size
);
Ok(metadata)
}
pub async fn load_snapshot(&self) -> Result<Option<(SnapshotMetadata, RdfApp)>> {
let snapshot_path = self.data_dir.join("snapshot.json");
let metadata_path = self.data_dir.join("snapshot_metadata.json");
if !snapshot_path.exists() || !metadata_path.exists() {
return Ok(None);
}
let metadata_data = fs::read(&metadata_path)?;
let metadata: SnapshotMetadata = serde_json::from_slice(&metadata_data)?;
let snapshot_data = fs::read(&snapshot_path)?;
let app_state: RdfApp = serde_json::from_slice(&snapshot_data)?;
Ok(Some((metadata, app_state)))
}
pub async fn compact_log(&self, until_index: u64) -> Result<()> {
{
let mut state = self.raft_state.write().await;
state.log.retain(|entry| entry.index > until_index);
}
self.save_state().await
}
pub async fn needs_compaction(&self) -> bool {
let state = self.raft_state.read().await;
state.log.len() > self.config.max_log_entries
}
async fn save_state(&self) -> Result<()> {
let state = self.raft_state.read().await;
if self.config.enable_wal {
self.write_wal_entry(WalOperation::WriteRaftState(state.clone()))
.await?;
}
self.atomic_write_with_checksum("raft_state.dat", &*state)
.await?;
if self.config.enable_wal {
let sequence = *self.wal_sequence.read().await;
self.write_wal_entry(WalOperation::Commit(sequence)).await?;
}
Ok(())
}
async fn write_wal_entry(&self, operation: WalOperation) -> Result<()> {
if !self.config.enable_wal {
return Ok(());
}
let mut wal_sequence = self.wal_sequence.write().await;
*wal_sequence += 1;
let sequence = *wal_sequence;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
let op_bytes = oxicode::serde::encode_to_vec(&operation, oxicode::config::standard())?;
let mut hasher = Sha256::new();
hasher.update(&op_bytes);
hasher.update(sequence.to_le_bytes());
hasher.update(timestamp.to_le_bytes());
let checksum = format!("{:x}", hasher.finalize());
let wal_entry = WalEntry {
sequence,
timestamp,
operation,
checksum,
};
let entry_bytes = oxicode::serde::encode_to_vec(&wal_entry, oxicode::config::standard())?;
let length = entry_bytes.len() as u64;
if let Some(ref mut writer) = self.wal_writer.write().await.as_mut() {
writer.write_all(&length.to_le_bytes())?;
writer.write_all(&entry_bytes)?;
writer.flush()?;
if self.config.sync_writes {
writer.get_mut().sync_all()?;
}
}
Ok(())
}
async fn atomic_write_with_checksum<T>(&self, filename: &str, data: &T) -> Result<()>
where
T: Serialize,
{
let path = self.data_dir.join(filename);
let temp_path = self.data_dir.join(format!("{filename}.tmp"));
let checksummed_data = if self.config.enable_corruption_detection {
ChecksummedData::new(data)?
} else {
ChecksummedData {
data,
checksum: String::new(),
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
}
};
let serialized =
oxicode::serde::encode_to_vec(&checksummed_data, oxicode::config::standard())?;
{
let temp_file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp_path)?;
let mut writer = BufWriter::new(temp_file);
writer.write_all(&serialized)?;
writer.flush()?;
if self.config.sync_writes {
writer.get_mut().sync_all()?;
}
}
std::fs::rename(&temp_path, &path)?;
Ok(())
}
async fn load_state(&self) -> Result<()> {
let binary_path = self.data_dir.join("raft_state.dat");
if binary_path.exists() {
match self.load_with_checksum::<RaftState>(&binary_path).await {
Ok(state) => {
*self.raft_state.write().await = state;
tracing::info!("Loaded Raft state (binary) for node {}", self.node_id);
}
Err(e) => {
tracing::error!("Failed to load binary Raft state: {}", e);
if self.config.enable_corruption_detection {
return Err(anyhow!("Corrupted Raft state file"));
}
}
}
} else {
let json_path = self.data_dir.join("raft_state.json");
if json_path.exists() {
let data = std::fs::read(&json_path)?;
let state: RaftState = serde_json::from_slice(&data)?;
*self.raft_state.write().await = state;
tracing::info!("Loaded Raft state (legacy JSON) for node {}", self.node_id);
}
}
let app_binary_path = self.data_dir.join("app_state.dat");
if app_binary_path.exists() {
match self.load_with_checksum::<RdfApp>(&app_binary_path).await {
Ok(app_state) => {
*self.app_state.write().await = app_state;
tracing::info!(
"Loaded application state (binary) for node {}",
self.node_id
);
}
Err(e) => {
tracing::error!("Failed to load binary application state: {}", e);
if self.config.enable_corruption_detection {
return Err(anyhow!("Corrupted application state file"));
}
}
}
} else {
let app_json_path = self.data_dir.join("app_state.json");
if app_json_path.exists() {
let data = std::fs::read(&app_json_path)?;
let app_state: RdfApp = serde_json::from_slice(&data)?;
*self.app_state.write().await = app_state;
tracing::info!(
"Loaded application state (legacy JSON) for node {}",
self.node_id
);
}
}
Ok(())
}
async fn load_with_checksum<T>(&self, path: &Path) -> Result<T>
where
T: for<'de> Deserialize<'de> + Serialize,
{
let data = std::fs::read(path)?;
let (checksummed_data, _): (ChecksummedData<T>, _) =
oxicode::serde::decode_from_slice(&data, oxicode::config::standard())?;
if self.config.enable_corruption_detection && !checksummed_data.verify()? {
return Err(anyhow!("Checksum verification failed for {:?}", path));
}
Ok(checksummed_data.data)
}
async fn save_app_state(&self) -> Result<()> {
let app_state = self.app_state.read().await;
if self.config.enable_wal {
self.write_wal_entry(WalOperation::WriteAppState(app_state.clone()))
.await?;
}
self.atomic_write_with_checksum("app_state.dat", &*app_state)
.await?;
if self.config.enable_wal {
let sequence = *self.wal_sequence.read().await;
self.write_wal_entry(WalOperation::Commit(sequence)).await?;
}
Ok(())
}
pub async fn get_stats(&self) -> StorageStats {
let raft_state = self.raft_state.read().await;
let app_state = self.app_state.read().await;
let data_dir_size = self.calculate_directory_size(&self.data_dir).unwrap_or(0);
StorageStats {
node_id: self.node_id,
data_dir: self.data_dir.clone(),
log_entries: raft_state.log.len(),
current_term: raft_state.current_term,
commit_index: raft_state.commit_index,
last_applied: raft_state.last_applied,
triple_count: app_state.len(),
disk_usage_bytes: data_dir_size,
}
}
#[allow(clippy::only_used_in_recursion)]
fn calculate_directory_size(&self, dir: &Path) -> Result<u64> {
let mut size = 0;
if dir.is_dir() {
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_file() {
size += fs::metadata(&path)?.len();
} else if path.is_dir() {
size += self.calculate_directory_size(&path)?;
}
}
}
Ok(size)
}
pub async fn backup(&self) -> Result<PathBuf> {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
let backup_dir = self
.data_dir
.parent()
.expect("data_dir should have a parent directory")
.join(format!("backup-{}-{}", self.node_id, timestamp));
fs::create_dir_all(&backup_dir)?;
for entry in fs::read_dir(&self.data_dir)? {
let entry = entry?;
let source = entry.path();
let dest = backup_dir.join(entry.file_name());
fs::copy(&source, &dest)?;
}
tracing::info!("Created backup at {:?}", backup_dir);
Ok(backup_dir)
}
pub async fn rotate_wal(&self) -> Result<()> {
if !self.config.enable_wal {
return Ok(());
}
let wal_path = self.data_dir.join("wal.log");
if !wal_path.exists() {
return Ok(());
}
let metadata = std::fs::metadata(&wal_path)?;
if metadata.len() < 100 * 1024 * 1024 {
return Ok(());
}
*self.wal_writer.write().await = None;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
let archive_path = self.data_dir.join(format!("wal-{timestamp}.log"));
std::fs::rename(&wal_path, &archive_path)?;
self.init_wal().await?;
tracing::info!(
"Rotated WAL for node {}, archived to {:?}",
self.node_id,
archive_path
);
Ok(())
}
pub async fn compact_wal(&self) -> Result<()> {
if !self.config.enable_wal {
return Ok(());
}
let wal_path = self.data_dir.join("wal.log");
if !wal_path.exists() {
return Ok(());
}
let file = File::open(&wal_path)?;
let mut reader = BufReader::new(file);
let mut entries = Vec::new();
let mut last_commit_sequence = 0;
loop {
let mut length_bytes = [0u8; 8];
match reader.read_exact(&mut length_bytes) {
Ok(_) => {
let length = u64::from_le_bytes(length_bytes);
let mut entry_bytes = vec![0u8; length as usize];
reader.read_exact(&mut entry_bytes)?;
if let Ok(entry) =
oxicode::serde::decode_from_slice(&entry_bytes, oxicode::config::standard())
.map(|(v, _): (WalEntry, _)| v)
{
if let WalOperation::Commit(seq) = &entry.operation {
last_commit_sequence = *seq;
}
entries.push(entry);
}
}
Err(_) => break,
}
}
let total_entries = entries.len();
let uncommitted: Vec<_> = entries
.into_iter()
.filter(|entry| entry.sequence > last_commit_sequence)
.collect();
let temp_wal_path = self.data_dir.join("wal.log.tmp");
{
let temp_file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp_wal_path)?;
let mut writer = BufWriter::new(temp_file);
for entry in &uncommitted {
let entry_bytes =
oxicode::serde::encode_to_vec(entry, oxicode::config::standard())?;
let length = entry_bytes.len() as u64;
writer.write_all(&length.to_le_bytes())?;
writer.write_all(&entry_bytes)?;
}
writer.flush()?;
if self.config.sync_writes {
writer.get_mut().sync_all()?;
}
}
std::fs::rename(&temp_wal_path, &wal_path)?;
*self.wal_writer.write().await = None;
self.init_wal().await?;
tracing::info!(
"Compacted WAL for node {}, removed {} committed entries, kept {} uncommitted",
self.node_id,
total_entries - uncommitted.len(),
uncommitted.len()
);
Ok(())
}
pub async fn verify_integrity(&self) -> Result<bool> {
let mut all_valid = true;
let raft_path = self.data_dir.join("raft_state.dat");
if raft_path.exists() {
match self.load_with_checksum::<RaftState>(&raft_path).await {
Ok(_) => tracing::info!("Raft state integrity verified"),
Err(e) => {
tracing::error!("Raft state integrity check failed: {}", e);
all_valid = false;
}
}
}
let app_path = self.data_dir.join("app_state.dat");
if app_path.exists() {
match self.load_with_checksum::<RdfApp>(&app_path).await {
Ok(_) => tracing::info!("Application state integrity verified"),
Err(e) => {
tracing::error!("Application state integrity check failed: {}", e);
all_valid = false;
}
}
}
if self.config.enable_wal {
let wal_path = self.data_dir.join("wal.log");
if wal_path.exists() {
match self.verify_wal_integrity(&wal_path).await {
Ok(valid_entries) => {
tracing::info!("WAL integrity verified, {} valid entries", valid_entries)
}
Err(e) => {
tracing::error!("WAL integrity check failed: {}", e);
all_valid = false;
}
}
}
}
Ok(all_valid)
}
async fn verify_wal_integrity(&self, wal_path: &Path) -> Result<usize> {
let file = File::open(wal_path)?;
let mut reader = BufReader::new(file);
let mut valid_entries = 0;
loop {
let mut length_bytes = [0u8; 8];
match reader.read_exact(&mut length_bytes) {
Ok(_) => {
let length = u64::from_le_bytes(length_bytes);
let mut entry_bytes = vec![0u8; length as usize];
reader.read_exact(&mut entry_bytes)?;
if let Ok(entry) =
oxicode::serde::decode_from_slice(&entry_bytes, oxicode::config::standard())
.map(|(v, _): (WalEntry, _)| v)
{
if self.verify_wal_entry_checksum(&entry)? {
valid_entries += 1;
} else {
return Err(anyhow!(
"Invalid checksum for WAL entry at sequence {}",
entry.sequence
));
}
} else {
return Err(anyhow!("Failed to deserialize WAL entry"));
}
}
Err(_) => break, }
}
Ok(valid_entries)
}
pub async fn cleanup_old_backups(&self) -> Result<()> {
let parent_dir = self
.data_dir
.parent()
.expect("data_dir should have a parent directory");
let backup_prefix = format!("backup-{}-", self.node_id);
let mut backups = Vec::new();
for entry in fs::read_dir(parent_dir)? {
let entry = entry?;
let name = entry.file_name();
if let Some(name_str) = name.to_str() {
if name_str.starts_with(&backup_prefix) {
backups.push((entry.path(), name_str.to_string()));
}
}
}
backups.sort_by(|a, b| b.1.cmp(&a.1));
for (path, _) in backups.iter().skip(self.config.backup_retention) {
fs::remove_dir_all(path)?;
tracing::info!("Removed old backup: {:?}", path);
}
Ok(())
}
pub async fn recover_from_crash(&self) -> Result<RecoveryReport> {
if !self.config.enable_crash_recovery {
return Ok(RecoveryReport::new());
}
let mut report = RecoveryReport::new();
if self.config.enable_wal {
report.wal_recovered = self.recover_from_wal().await?;
}
if self.config.enable_corruption_detection {
let corruption_report = self.check_file_integrity().await?;
report.corrupted_files = corruption_report.corrupted_files;
report.recovered_files = corruption_report.recovered_files;
}
let log_consistency = self.verify_log_consistency().await?;
if !log_consistency.is_consistent {
report.log_inconsistencies = log_consistency.issues.len();
self.repair_log_inconsistencies(log_consistency.issues)
.await?;
}
let state_consistency = self.verify_state_consistency().await?;
report.state_machine_repaired = state_consistency.repaired;
tracing::info!("Crash recovery completed: {:?}", report);
Ok(report)
}
async fn recover_from_wal(&self) -> Result<bool> {
let wal_path = self.data_dir.join("wal.log");
if !wal_path.exists() {
return Ok(false);
}
tracing::info!("Recovering from write-ahead log");
let wal_data = fs::read(&wal_path)?;
if wal_data.is_empty() {
fs::remove_file(&wal_path)?;
return Ok(false);
}
if let Ok(operations) = serde_json::from_slice::<Vec<WalOperation>>(&wal_data) {
for operation in operations {
match operation {
WalOperation::WriteRaftState(state) => {
let state_json = serde_json::to_string(&state)?;
let state_path = self.data_dir.join("raft_state.json");
fs::write(&state_path, state_json)?;
}
WalOperation::WriteAppState(app_state) => {
let app_json = serde_json::to_string(&app_state)?;
let app_state_path = self.data_dir.join("app_state.json");
fs::write(&app_state_path, app_json)?;
}
WalOperation::CreateSnapshot(metadata) => {
let metadata_json = serde_json::to_string(&metadata)?;
let snapshot_path = self.data_dir.join("snapshot_metadata.json");
fs::write(&snapshot_path, metadata_json)?;
}
WalOperation::TruncateLog(_index) => {
}
WalOperation::Commit(_sequence) => {
}
}
}
fs::remove_file(&wal_path)?;
tracing::info!("Successfully recovered from WAL");
return Ok(true);
}
fs::remove_file(&wal_path)?;
tracing::warn!("WAL file was corrupted and removed");
Ok(false)
}
async fn check_file_integrity(&self) -> Result<CorruptionReport> {
let mut report = CorruptionReport::new();
let files_to_check = vec![
("raft_state.json", "raft_state.json.checksum"),
("app_state.json", "app_state.json.checksum"),
("snapshot.json", "snapshot.json.checksum"),
];
for (filename, checksum_filename) in files_to_check {
let file_path = self.data_dir.join(filename);
let checksum_path = self.data_dir.join(checksum_filename);
if file_path.exists() {
let integrity_ok = self
.verify_file_checksum(&file_path, &checksum_path)
.await?;
if !integrity_ok {
report.corrupted_files.push(filename.to_string());
if self.recover_corrupted_file(&file_path).await? {
report.recovered_files.push(filename.to_string());
}
}
}
}
Ok(report)
}
async fn verify_file_checksum(&self, file_path: &Path, checksum_path: &Path) -> Result<bool> {
if !checksum_path.exists() {
let checksum = self.calculate_file_checksum(file_path).await?;
fs::write(checksum_path, checksum)?;
return Ok(true);
}
let stored_checksum = fs::read_to_string(checksum_path)?;
let current_checksum = self.calculate_file_checksum(file_path).await?;
Ok(stored_checksum.trim() == current_checksum)
}
async fn calculate_file_checksum(&self, file_path: &Path) -> Result<String> {
use sha2::{Digest, Sha256};
let data = fs::read(file_path)?;
let mut hasher = Sha256::new();
hasher.update(data);
let result = hasher.finalize();
Ok(format!("{result:x}"))
}
async fn recover_corrupted_file(&self, file_path: &Path) -> Result<bool> {
let filename = file_path
.file_name()
.expect("file_path should have a file name")
.to_string_lossy();
let parent_dir = self
.data_dir
.parent()
.expect("data_dir should have a parent directory");
let backup_prefix = format!("backup-{}-", self.node_id);
let mut backups = Vec::new();
for entry in fs::read_dir(parent_dir)? {
let entry = entry?;
let name = entry.file_name();
if let Some(name_str) = name.to_str() {
if name_str.starts_with(&backup_prefix) {
let backup_file_path = entry.path().join(&*filename);
if backup_file_path.exists() {
backups.push((backup_file_path, name_str.to_string()));
}
}
}
}
if backups.is_empty() {
return Ok(false);
}
backups.sort_by(|a, b| b.1.cmp(&a.1));
let (backup_path, _) = &backups[0];
fs::copy(backup_path, file_path)?;
let checksum_path = file_path.with_extension(format!(
"{}.checksum",
file_path.extension().unwrap_or_default().to_string_lossy()
));
let checksum = self.calculate_file_checksum(file_path).await?;
fs::write(&checksum_path, checksum)?;
tracing::info!("Recovered corrupted file {} from backup", filename);
Ok(true)
}
async fn verify_log_consistency(&self) -> Result<LogConsistencyReport> {
let state = self.raft_state.read().await;
let mut report = LogConsistencyReport::new();
let mut expected_index = 1u64;
for entry in &state.log {
if entry.index != expected_index {
report.issues.push(LogInconsistency::IndexGap {
expected: expected_index,
found: entry.index,
});
}
expected_index = entry.index + 1;
}
let mut indices = std::collections::HashSet::new();
for entry in &state.log {
if !indices.insert(entry.index) {
report
.issues
.push(LogInconsistency::DuplicateIndex { index: entry.index });
}
}
if state.commit_index > state.log.last().map(|e| e.index).unwrap_or(0) {
report.issues.push(LogInconsistency::InvalidCommitIndex {
commit_index: state.commit_index,
last_log_index: state.log.last().map(|e| e.index).unwrap_or(0),
});
}
report.is_consistent = report.issues.is_empty();
Ok(report)
}
async fn repair_log_inconsistencies(&self, issues: Vec<LogInconsistency>) -> Result<()> {
let mut state = self.raft_state.write().await;
for issue in issues {
match issue {
LogInconsistency::IndexGap { expected, found } => {
tracing::warn!(
"Fixing log index gap: expected {}, found {}",
expected,
found
);
state
.log
.retain(|entry| entry.index < expected || entry.index >= found);
}
LogInconsistency::DuplicateIndex { index } => {
tracing::warn!("Removing duplicate log entry at index {}", index);
let mut seen = false;
state.log.retain(|entry| {
if entry.index == index {
if seen {
false
} else {
seen = true;
true
}
} else {
true
}
});
}
LogInconsistency::InvalidCommitIndex {
commit_index,
last_log_index,
} => {
tracing::warn!(
"Fixing invalid commit index: {} > {}",
commit_index,
last_log_index
);
state.commit_index = last_log_index;
}
}
}
drop(state);
self.save_state().await
}
async fn verify_state_consistency(&self) -> Result<StateConsistencyReport> {
let state = self.raft_state.read().await;
let app_state = self.app_state.read().await;
let mut report = StateConsistencyReport::new();
if state.last_applied < state.commit_index {
report.repaired = true;
tracing::info!(
"Applying unapplied committed entries: {} to {}",
state.last_applied + 1,
state.commit_index
);
for entry in &state.log {
if entry.index > state.last_applied && entry.index <= state.commit_index {
}
}
drop(app_state);
drop(state);
self.set_last_applied(self.raft_state.read().await.commit_index)
.await?;
}
Ok(report)
}
#[allow(dead_code)]
async fn write_to_wal(&self, operation: WalOperation) -> Result<()> {
if !self.config.enable_wal {
return Ok(());
}
let wal_path = self.data_dir.join("wal.log");
let mut operations = if wal_path.exists() {
let data = fs::read(&wal_path)?;
if data.is_empty() {
Vec::new()
} else {
serde_json::from_slice(&data).unwrap_or_default()
}
} else {
Vec::new()
};
operations.push(operation);
let data = serde_json::to_vec(&operations)?;
fs::write(&wal_path, data)?;
if self.config.sync_writes {
let file = fs::File::open(&wal_path)?;
file.sync_all()?;
}
Ok(())
}
#[allow(dead_code)]
async fn clear_wal(&self) -> Result<()> {
if !self.config.enable_wal {
return Ok(());
}
let wal_path = self.data_dir.join("wal.log");
if wal_path.exists() {
fs::remove_file(&wal_path)?;
}
Ok(())
}
#[allow(dead_code)]
async fn update_file_checksum(&self, file_path: &Path) -> Result<()> {
if !self.config.enable_corruption_detection {
return Ok(());
}
let checksum_path = file_path.with_extension(format!(
"{}.checksum",
file_path.extension().unwrap_or_default().to_string_lossy()
));
let checksum = self.calculate_file_checksum(file_path).await?;
fs::write(&checksum_path, checksum)?;
Ok(())
}
}
#[async_trait]
impl StorageBackend for PersistentStorage {
async fn create_shard(&self, shard_id: ShardId) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.create_shard(shard_id);
self.save_app_state().await
}
async fn delete_shard(&self, shard_id: ShardId) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.delete_shard(shard_id);
self.save_app_state().await
}
async fn insert_triple_to_shard(&self, shard_id: ShardId, triple: Triple) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.insert_triple_to_shard(shard_id, triple);
self.save_app_state().await
}
async fn delete_triple_from_shard(&self, shard_id: ShardId, triple: &Triple) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.delete_triple_from_shard(shard_id, triple);
self.save_app_state().await
}
async fn query_shard(
&self,
shard_id: ShardId,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
) -> Result<Vec<Triple>> {
let app_state = self.app_state.read().await;
Ok(app_state.query_shard(shard_id, subject, predicate, object))
}
async fn get_shard_size(&self, shard_id: ShardId) -> Result<u64> {
let app_state = self.app_state.read().await;
Ok(app_state.get_shard_size(shard_id))
}
async fn get_shard_triple_count(&self, shard_id: ShardId) -> Result<usize> {
let app_state = self.app_state.read().await;
Ok(app_state.get_shard_triple_count(shard_id))
}
async fn export_shard(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
let app_state = self.app_state.read().await;
Ok(app_state.export_shard(shard_id))
}
async fn import_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.import_shard(shard_id, triples);
self.save_app_state().await
}
async fn get_shard_triples(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
let app_state = self.app_state.read().await;
Ok(app_state.get_shard_triples(shard_id))
}
async fn insert_triples_to_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.insert_triples_to_shard(shard_id, triples);
self.save_app_state().await
}
async fn mark_shard_for_deletion(&self, shard_id: ShardId) -> Result<()> {
let mut app_state = self.app_state.write().await;
app_state.mark_shard_for_deletion(shard_id);
self.save_app_state().await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::raft::RdfCommand;
use tempfile::TempDir;
async fn create_test_storage() -> (PersistentStorage, TempDir) {
let temp_dir = TempDir::new().unwrap();
let config = StorageConfig {
data_dir: temp_dir.path().to_string_lossy().to_string(),
sync_writes: false, max_log_entries: 100,
compress_snapshots: false,
backup_retention: 2,
enable_corruption_detection: false, enable_crash_recovery: false, enable_wal: false, };
let node_id = std::process::id() as u64
+ std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
let storage = PersistentStorage::new(node_id, config).await.unwrap();
(storage, temp_dir)
}
#[tokio::test]
async fn test_storage_creation() {
let (storage, _temp_dir) = create_test_storage().await;
assert!(storage.node_id > 0); assert_eq!(storage.get_current_term().await, 0);
assert_eq!(storage.get_voted_for().await, None);
}
#[tokio::test]
async fn test_term_operations() {
let (storage, _temp_dir) = create_test_storage().await;
storage.set_current_term(5).await.unwrap();
assert_eq!(storage.get_current_term().await, 5);
assert_eq!(storage.get_voted_for().await, None);
}
#[tokio::test]
async fn test_vote_operations() {
let (storage, _temp_dir) = create_test_storage().await;
storage.set_voted_for(Some(2)).await.unwrap();
assert_eq!(storage.get_voted_for().await, Some(2));
storage.set_voted_for(None).await.unwrap();
assert_eq!(storage.get_voted_for().await, None);
}
#[tokio::test]
async fn test_log_operations() {
let (storage, _temp_dir) = create_test_storage().await;
let command = RdfCommand::Insert {
subject: "s".to_string(),
predicate: "p".to_string(),
object: "o".to_string(),
};
let entry = LogEntry::new(1, 1, command);
storage.append_entries(vec![entry.clone()]).await.unwrap();
assert_eq!(storage.get_last_log_index().await, 1);
assert_eq!(storage.get_last_log_term().await, 1);
let retrieved = storage.get_log_entry(1).await.unwrap();
assert_eq!(retrieved.index, 1);
assert_eq!(retrieved.term, 1);
let entries = storage.get_log_entries(1, 2).await;
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].index, 1);
}
#[tokio::test]
async fn test_commit_operations() {
let (storage, _temp_dir) = create_test_storage().await;
storage.set_commit_index(5).await.unwrap();
assert_eq!(storage.get_commit_index().await, 5);
storage.set_last_applied(3).await.unwrap();
assert_eq!(storage.get_last_applied().await, 3);
}
#[tokio::test]
async fn test_application_state() {
let temp_dir = TempDir::new().unwrap();
let config = StorageConfig {
data_dir: temp_dir.path().to_string_lossy().to_string(),
sync_writes: false,
max_log_entries: 100,
compress_snapshots: false,
backup_retention: 2,
enable_corruption_detection: false,
enable_crash_recovery: false,
enable_wal: false,
};
let node_id = std::process::id() as u64
+ std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
println!("Creating storage...");
let storage = PersistentStorage::new(node_id, config).await.unwrap();
println!("Storage created");
println!("Getting app state...");
let app_state = storage.get_app_state().await;
println!("App state retrieved, length: {}", app_state.len());
assert_eq!(app_state.len(), 0);
println!("Modifying app state directly...");
{
let mut app_state = storage.app_state.write().await;
println!("Got write lock");
let command = RdfCommand::Insert {
subject: "s".to_string(),
predicate: "p".to_string(),
object: "o".to_string(),
};
app_state.apply_command(&command);
println!("Applied command to in-memory state");
}
println!("Getting updated app state...");
let app_state = storage.get_app_state().await;
println!("Updated app state retrieved, length: {}", app_state.len());
assert_eq!(app_state.len(), 1);
println!("Test completed successfully");
}
#[tokio::test]
async fn test_log_truncation() {
let (storage, _temp_dir) = create_test_storage().await;
for i in 1..=5 {
let command = RdfCommand::Insert {
subject: format!("s{i}"),
predicate: "p".to_string(),
object: "o".to_string(),
};
let entry = LogEntry::new(i, 1, command);
storage.append_entries(vec![entry]).await.unwrap();
}
assert_eq!(storage.get_last_log_index().await, 5);
storage.truncate_log(3).await.unwrap();
assert_eq!(storage.get_last_log_index().await, 2);
assert!(storage.get_log_entry(3).await.is_none());
assert!(storage.get_log_entry(4).await.is_none());
assert!(storage.get_log_entry(5).await.is_none());
assert!(storage.get_log_entry(1).await.is_some());
assert!(storage.get_log_entry(2).await.is_some());
}
#[tokio::test]
async fn test_snapshot_operations() {
let (storage, _temp_dir) = create_test_storage().await;
{
let mut app_state = storage.app_state.write().await;
let command = RdfCommand::Insert {
subject: "s".to_string(),
predicate: "p".to_string(),
object: "o".to_string(),
};
app_state.apply_command(&command);
}
let metadata = storage.create_snapshot().await.unwrap();
assert!(metadata.size > 0);
let loaded = storage.load_snapshot().await.unwrap();
assert!(loaded.is_some());
let (loaded_metadata, loaded_state) = loaded.unwrap();
assert_eq!(loaded_metadata.size, metadata.size);
assert_eq!(loaded_state.len(), 1);
}
#[tokio::test]
async fn test_compaction_check() {
let temp_dir = TempDir::new().unwrap();
let config = StorageConfig {
data_dir: temp_dir.path().to_string_lossy().to_string(),
max_log_entries: 3, ..Default::default()
};
let storage = PersistentStorage::new(1, config).await.unwrap();
assert!(!storage.needs_compaction().await);
for i in 1..=5 {
let command = RdfCommand::Insert {
subject: format!("s{i}"),
predicate: "p".to_string(),
object: "o".to_string(),
};
let entry = LogEntry::new(i, 1, command);
storage.append_entries(vec![entry]).await.unwrap();
}
assert!(storage.needs_compaction().await);
}
#[tokio::test]
async fn test_storage_stats() {
let (storage, _temp_dir) = create_test_storage().await;
let stats = storage.get_stats().await;
assert!(stats.node_id > 0);
assert_eq!(stats.log_entries, 0);
assert_eq!(stats.current_term, 0);
assert_eq!(stats.triple_count, 0);
}
#[tokio::test]
async fn test_persistence() {
let temp_dir = TempDir::new().unwrap();
let config = StorageConfig {
data_dir: temp_dir.path().to_string_lossy().to_string(),
sync_writes: false,
enable_corruption_detection: false,
enable_crash_recovery: false,
enable_wal: false,
..Default::default()
};
{
let storage = PersistentStorage::new(1, config.clone()).await.unwrap();
storage.set_current_term(5).await.unwrap();
storage.set_voted_for(Some(2)).await.unwrap();
let command = RdfCommand::Insert {
subject: "s".to_string(),
predicate: "p".to_string(),
object: "o".to_string(),
};
let entry = LogEntry::new(1, 1, command.clone());
storage.append_entries(vec![entry]).await.unwrap();
{
let mut app_state = storage.app_state.write().await;
app_state.apply_command(&command);
}
}
{
let storage = PersistentStorage::new(1, config).await.unwrap();
assert_eq!(storage.get_current_term().await, 5);
assert_eq!(storage.get_voted_for().await, Some(2));
assert_eq!(storage.get_last_log_index().await, 1);
}
}
#[test]
fn test_storage_error_display() {
let err = StorageError::Corruption {
file: "log.dat".to_string(),
message: "checksum mismatch".to_string(),
};
assert!(err
.to_string()
.contains("Corruption detected in log.dat: checksum mismatch"));
let err = StorageError::LogEntryNotFound { index: 42 };
assert!(err.to_string().contains("Log entry not found at index 42"));
let err = StorageError::InvalidRange { start: 10, end: 5 };
assert!(err.to_string().contains("Invalid log range: 10 to 5"));
}
}