use crate::optimization::{
AtomicFileWriter, BinarySerializer, CorruptionDetector, SerializationConfig,
};
use crate::raft::{OxirsNodeId, RdfApp};
use crate::raft_profiling::{RaftOperation, RaftProfiler};
use crate::storage::SnapshotMetadata;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info};
type TripleModification = ((String, String, String), (String, String, String));
pub const SNAPSHOT_FORMAT_VERSION: u32 = 1;
pub const CHUNK_SIZE: usize = 1024 * 1024;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedSnapshotMetadata {
pub base: SnapshotMetadata,
pub format_version: u32,
pub compression: String,
pub total_chunks: u32,
pub chunk_size: usize,
pub chunk_checksums: Vec<u32>,
pub is_incremental: bool,
pub base_snapshot_id: Option<String>,
pub snapshot_id: String,
pub created_by: OxirsNodeId,
}
#[derive(Debug, Clone)]
pub struct SnapshotOptions {
pub enable_compression: bool,
pub compression_level: i32,
pub incremental: bool,
pub base_snapshot_id: Option<String>,
pub chunk_size: usize,
pub verify_integrity: bool,
}
impl Default for SnapshotOptions {
fn default() -> Self {
Self {
enable_compression: true,
compression_level: 6,
incremental: false,
base_snapshot_id: None,
chunk_size: CHUNK_SIZE,
verify_integrity: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferStatus {
pub total_bytes: u64,
pub transferred_bytes: u64,
pub current_chunk: u32,
pub total_chunks: u32,
pub transfer_rate: f64,
pub eta_seconds: u64,
pub started_at: u64,
pub last_update: u64,
}
impl TransferStatus {
pub fn new(total_bytes: u64, total_chunks: u32) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
Self {
total_bytes,
transferred_bytes: 0,
current_chunk: 0,
total_chunks,
transfer_rate: 0.0,
eta_seconds: 0,
started_at: now,
last_update: now,
}
}
pub fn update(&mut self, chunk_index: u32, chunk_size: u64) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
self.current_chunk = chunk_index;
self.transferred_bytes += chunk_size;
let elapsed = now.saturating_sub(self.started_at);
if elapsed > 0 {
self.transfer_rate = self.transferred_bytes as f64 / elapsed as f64;
let remaining_bytes = self.total_bytes.saturating_sub(self.transferred_bytes);
if self.transfer_rate > 0.0 {
self.eta_seconds = (remaining_bytes as f64 / self.transfer_rate) as u64;
}
}
self.last_update = now;
}
pub fn progress_percentage(&self) -> f64 {
if self.total_bytes == 0 {
100.0
} else {
(self.transferred_bytes as f64 / self.total_bytes as f64) * 100.0
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotDiff {
pub additions: Vec<(String, String, String)>,
pub deletions: Vec<(String, String, String)>,
pub modifications: Vec<TripleModification>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SnapshotData {
Full(RdfApp),
Diff(SnapshotDiff),
}
pub struct EnhancedSnapshotManager {
node_id: OxirsNodeId,
data_dir: PathBuf,
serializer: Arc<Mutex<BinarySerializer>>,
corruption_detector: CorruptionDetector,
active_transfers: Arc<RwLock<HashMap<String, TransferStatus>>>,
snapshot_cache: Arc<RwLock<HashMap<String, EnhancedSnapshotMetadata>>>,
profiler: Arc<RaftProfiler>,
}
impl EnhancedSnapshotManager {
pub fn new(node_id: OxirsNodeId, data_dir: PathBuf) -> Self {
let serialization_config = SerializationConfig::default();
let serializer = BinarySerializer::new(serialization_config);
let corruption_detector = CorruptionDetector::new(true);
Self {
node_id,
data_dir,
serializer: Arc::new(Mutex::new(serializer)),
corruption_detector,
active_transfers: Arc::new(RwLock::new(HashMap::new())),
snapshot_cache: Arc::new(RwLock::new(HashMap::new())),
profiler: Arc::new(RaftProfiler::new(node_id)),
}
}
pub fn profiler(&self) -> &Arc<RaftProfiler> {
&self.profiler
}
pub async fn create_snapshot(
&self,
app_state: &RdfApp,
options: SnapshotOptions,
) -> Result<EnhancedSnapshotMetadata> {
let prof_op = self
.profiler
.start_operation(RaftOperation::CreateSnapshot)
.await;
let snapshot_id = uuid::Uuid::new_v4().to_string();
info!(
"Creating enhanced snapshot {} with options: {:?}",
snapshot_id, options
);
let snapshot_dir = self.data_dir.join("snapshots").join(&snapshot_id);
tokio::fs::create_dir_all(&snapshot_dir).await?;
let snapshot_data = if options.incremental {
SnapshotData::Diff(
self.create_incremental_snapshot(app_state, &options)
.await?,
)
} else {
SnapshotData::Full(self.create_full_snapshot(app_state).await?)
};
let serializer = self.serializer.lock().await;
let serialized_data = serializer.serialize(&snapshot_data)?;
drop(serializer);
let chunks = self.create_chunks(&serialized_data, options.chunk_size);
let chunk_checksums = chunks.iter().map(|chunk| crc32fast::hash(chunk)).collect();
for (i, chunk) in chunks.iter().enumerate() {
let chunk_path = snapshot_dir.join(format!("chunk_{i:06}.dat"));
let mut writer = AtomicFileWriter::new(&chunk_path).await?;
writer.write_all(chunk).await?;
writer.commit().await?;
}
let base_metadata = SnapshotMetadata {
last_included_index: 0, last_included_term: 0, configuration: vec![self.node_id],
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
size: serialized_data.len() as u64,
checksum: {
let mut hasher = Sha256::new();
hasher.update(&serialized_data);
format!("{:x}", hasher.finalize())
},
};
let enhanced_metadata = EnhancedSnapshotMetadata {
base: base_metadata,
format_version: SNAPSHOT_FORMAT_VERSION,
compression: if options.enable_compression {
"lz4".to_string()
} else {
"none".to_string()
},
total_chunks: chunks.len() as u32,
chunk_size: options.chunk_size,
chunk_checksums,
is_incremental: options.incremental,
base_snapshot_id: options.base_snapshot_id,
snapshot_id: snapshot_id.clone(),
created_by: self.node_id,
};
let metadata_path = snapshot_dir.join("metadata.json");
let serializer = self.serializer.lock().await;
let metadata_data = serializer.serialize(&enhanced_metadata)?;
drop(serializer);
let mut writer = AtomicFileWriter::new(&metadata_path).await?;
writer.write_all(&metadata_data).await?;
writer.commit().await?;
{
let mut cache = self.snapshot_cache.write().await;
cache.insert(snapshot_id.clone(), enhanced_metadata.clone());
}
info!(
"Created enhanced snapshot {} with {} chunks",
snapshot_id,
chunks.len()
);
self.profiler
.record_memory_usage("snapshot", enhanced_metadata.base.size)
.await;
prof_op.complete().await;
Ok(enhanced_metadata)
}
pub async fn load_snapshot(
&self,
snapshot_id: &str,
) -> Result<(EnhancedSnapshotMetadata, RdfApp)> {
info!("Loading enhanced snapshot {}", snapshot_id);
{
let cache = self.snapshot_cache.read().await;
if let Some(metadata) = cache.get(snapshot_id) {
return self.load_snapshot_with_metadata(metadata).await;
}
}
let snapshot_dir = self.data_dir.join("snapshots").join(snapshot_id);
let metadata_path = snapshot_dir.join("metadata.json");
if !metadata_path.exists() {
return Err(anyhow::anyhow!("Snapshot {} not found", snapshot_id));
}
if !self
.corruption_detector
.validate_file(&metadata_path)
.await?
{
return Err(anyhow::anyhow!("Snapshot metadata corrupted"));
}
let metadata_data = tokio::fs::read(&metadata_path).await?;
let metadata: EnhancedSnapshotMetadata =
self.serializer.lock().await.deserialize(&metadata_data)?;
{
let mut cache = self.snapshot_cache.write().await;
cache.insert(snapshot_id.to_string(), metadata.clone());
}
self.load_snapshot_with_metadata(&metadata).await
}
async fn load_snapshot_with_metadata(
&self,
metadata: &EnhancedSnapshotMetadata,
) -> Result<(EnhancedSnapshotMetadata, RdfApp)> {
let snapshot_dir = self.data_dir.join("snapshots").join(&metadata.snapshot_id);
let mut all_data = Vec::with_capacity(metadata.base.size as usize);
for i in 0..metadata.total_chunks {
let chunk_path = snapshot_dir.join(format!("chunk_{i:06}.dat"));
let chunk_data = tokio::fs::read(&chunk_path).await?;
let computed_checksum = crc32fast::hash(&chunk_data);
let expected_checksum = metadata.chunk_checksums[i as usize];
if computed_checksum != expected_checksum {
return Err(anyhow::anyhow!(
"Chunk {} checksum mismatch: expected {}, got {}",
i,
expected_checksum,
computed_checksum
));
}
all_data.extend_from_slice(&chunk_data);
}
let snapshot_data: SnapshotData = self.serializer.lock().await.deserialize(&all_data)?;
let app_state: RdfApp = match snapshot_data {
SnapshotData::Full(state) => state,
SnapshotData::Diff(_diff) => {
RdfApp::default()
}
};
info!(
"Loaded enhanced snapshot {} successfully",
metadata.snapshot_id
);
Ok((metadata.clone(), app_state))
}
pub async fn transfer_snapshot(
&self,
snapshot_id: &str,
target_node: OxirsNodeId,
target_address: &str,
) -> Result<String> {
let transfer_id = uuid::Uuid::new_v4().to_string();
let metadata = {
let cache = self.snapshot_cache.read().await;
cache.get(snapshot_id).cloned()
};
let metadata = match metadata {
Some(m) => m,
None => {
let (metadata, _) = self.load_snapshot(snapshot_id).await?;
metadata
}
};
let transfer_status = TransferStatus::new(metadata.base.size, metadata.total_chunks);
{
let mut transfers = self.active_transfers.write().await;
transfers.insert(transfer_id.clone(), transfer_status);
}
let manager = self.clone();
let transfer_id_clone = transfer_id.clone();
let metadata_clone = metadata.clone();
let target_address = target_address.to_string();
tokio::spawn(async move {
if let Err(e) = manager
.execute_snapshot_transfer(
&transfer_id_clone,
&metadata_clone,
target_node,
&target_address,
)
.await
{
error!("Snapshot transfer {} failed: {}", transfer_id_clone, e);
}
});
info!(
"Started snapshot transfer {} to node {}",
transfer_id, target_node
);
Ok(transfer_id)
}
async fn execute_snapshot_transfer(
&self,
transfer_id: &str,
metadata: &EnhancedSnapshotMetadata,
_target_node: OxirsNodeId,
_target_address: &str,
) -> Result<()> {
let snapshot_dir = self.data_dir.join("snapshots").join(&metadata.snapshot_id);
for i in 0..metadata.total_chunks {
let chunk_path = snapshot_dir.join(format!("chunk_{i:06}.dat"));
let chunk_data = tokio::fs::read(&chunk_path).await?;
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
{
let mut transfers = self.active_transfers.write().await;
if let Some(status) = transfers.get_mut(transfer_id) {
status.update(i, chunk_data.len() as u64);
}
}
debug!("Transferred chunk {} of {}", i + 1, metadata.total_chunks);
}
{
let mut transfers = self.active_transfers.write().await;
if let Some(status) = transfers.get_mut(transfer_id) {
status.current_chunk = status.total_chunks;
status.transferred_bytes = status.total_bytes;
}
}
info!("Completed snapshot transfer {}", transfer_id);
Ok(())
}
pub async fn get_transfer_status(&self, transfer_id: &str) -> Option<TransferStatus> {
let transfers = self.active_transfers.read().await;
transfers.get(transfer_id).cloned()
}
pub async fn list_snapshots(&self) -> Result<Vec<EnhancedSnapshotMetadata>> {
let snapshots_dir = self.data_dir.join("snapshots");
if !snapshots_dir.exists() {
return Ok(Vec::new());
}
let mut snapshots = Vec::new();
let mut entries = tokio::fs::read_dir(&snapshots_dir).await?;
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_dir() {
let snapshot_id = entry.file_name().to_string_lossy().to_string();
if let Ok((metadata, _)) = self.load_snapshot(&snapshot_id).await {
snapshots.push(metadata);
}
}
}
snapshots.sort_by_key(|b| std::cmp::Reverse(b.base.timestamp));
Ok(snapshots)
}
pub async fn delete_snapshot(&self, snapshot_id: &str) -> Result<()> {
let snapshot_dir = self.data_dir.join("snapshots").join(snapshot_id);
if snapshot_dir.exists() {
tokio::fs::remove_dir_all(&snapshot_dir).await?;
{
let mut cache = self.snapshot_cache.write().await;
cache.remove(snapshot_id);
}
info!("Deleted snapshot {}", snapshot_id);
}
Ok(())
}
async fn create_full_snapshot(&self, app_state: &RdfApp) -> Result<RdfApp> {
Ok(app_state.clone())
}
async fn create_incremental_snapshot(
&self,
_app_state: &RdfApp,
options: &SnapshotOptions,
) -> Result<SnapshotDiff> {
if let Some(_base_snapshot_id) = &options.base_snapshot_id {
Ok(SnapshotDiff {
additions: Vec::new(),
deletions: Vec::new(),
modifications: Vec::new(),
})
} else {
Err(anyhow::anyhow!(
"Base snapshot ID required for incremental snapshot"
))
}
}
fn create_chunks(&self, data: &[u8], chunk_size: usize) -> Vec<Vec<u8>> {
data.chunks(chunk_size)
.map(|chunk| chunk.to_vec())
.collect()
}
pub async fn cleanup_old_snapshots(&self, retention_count: usize) -> Result<()> {
let snapshots = self.list_snapshots().await?;
if snapshots.len() > retention_count {
let to_delete = &snapshots[retention_count..];
for snapshot in to_delete {
self.delete_snapshot(&snapshot.snapshot_id).await?;
}
info!("Cleaned up {} old snapshots", to_delete.len());
}
Ok(())
}
}
impl Clone for EnhancedSnapshotManager {
fn clone(&self) -> Self {
Self {
node_id: self.node_id,
data_dir: self.data_dir.clone(),
serializer: Arc::new(Mutex::new(BinarySerializer::new(
SerializationConfig::default(),
))),
corruption_detector: CorruptionDetector::new(true),
active_transfers: Arc::clone(&self.active_transfers),
snapshot_cache: Arc::clone(&self.snapshot_cache),
profiler: Arc::clone(&self.profiler),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
async fn create_test_manager() -> (EnhancedSnapshotManager, TempDir) {
let temp_dir = TempDir::new().unwrap();
let manager = EnhancedSnapshotManager::new(1, temp_dir.path().to_path_buf());
(manager, temp_dir)
}
#[tokio::test]
async fn test_create_and_load_snapshot() {
let (manager, temp_dir) = create_test_manager().await;
let mut app_state = RdfApp::default();
app_state
.triples
.insert(("s".to_string(), "p".to_string(), "o".to_string()));
let options = SnapshotOptions::default();
let metadata = manager.create_snapshot(&app_state, options).await.unwrap();
let (loaded_metadata, loaded_state) =
manager.load_snapshot(&metadata.snapshot_id).await.unwrap();
assert_eq!(loaded_metadata.snapshot_id, metadata.snapshot_id);
assert_eq!(loaded_state.triples.len(), 1);
drop(temp_dir);
}
#[tokio::test]
async fn test_snapshot_transfer() {
let (manager, _temp_dir) = create_test_manager().await;
let mut app_state = RdfApp::default();
app_state
.triples
.insert(("s".to_string(), "p".to_string(), "o".to_string()));
let options = SnapshotOptions::default();
let metadata = manager.create_snapshot(&app_state, options).await.unwrap();
let transfer_id = manager
.transfer_snapshot(&metadata.snapshot_id, 2, "127.0.0.1:8081")
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let status = manager.get_transfer_status(&transfer_id).await;
assert!(status.is_some());
}
#[tokio::test]
async fn test_list_snapshots() {
let (manager, _temp_dir) = create_test_manager().await;
let app_state = RdfApp::default();
let options = SnapshotOptions::default();
let _metadata1 = manager
.create_snapshot(&app_state, options.clone())
.await
.unwrap();
let _metadata2 = manager.create_snapshot(&app_state, options).await.unwrap();
let snapshots = manager.list_snapshots().await.unwrap();
assert_eq!(snapshots.len(), 2);
}
#[test]
fn test_transfer_status() {
let mut status = TransferStatus::new(1000, 10);
assert_eq!(status.progress_percentage(), 0.0);
status.update(5, 100);
assert_eq!(status.transferred_bytes, 100);
assert_eq!(status.current_chunk, 5);
assert!(status.progress_percentage() > 0.0);
}
#[test]
fn test_create_chunks() {
let manager = EnhancedSnapshotManager::new(1, PathBuf::from("/tmp"));
let data = vec![1u8; 250];
let chunks = manager.create_chunks(&data, 100);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].len(), 100);
assert_eq!(chunks[1].len(), 100);
assert_eq!(chunks[2].len(), 50);
}
}