use async_trait::async_trait;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use sha2::{Digest, Sha256};
use std::io::{Read, Write};
use std::path::PathBuf;
use tokio::fs;
use crate::error::{Result, SnapshotError};
use crate::snapshot::{Snapshot, SnapshotData};
/// Backend abstraction for persisting collection snapshots.
///
/// Implementations must be shareable across threads (`Send + Sync`) because the
/// trait is used through `async_trait`.
#[async_trait]
pub trait SnapshotStorage: Send + Sync {
    /// Persists `snapshot` and returns the [`Snapshot`] record describing what was stored.
    async fn save(&self, snapshot: &SnapshotData) -> Result<Snapshot>;

    /// Loads the snapshot data previously stored under `id`.
    async fn load(&self, id: &str) -> Result<SnapshotData>;

    /// Returns the records of all stored snapshots.
    async fn list(&self) -> Result<Vec<Snapshot>>;

    /// Removes the snapshot stored under `id`.
    async fn delete(&self, id: &str) -> Result<()>;
}
/// Snapshot storage backed by a directory on the local filesystem.
///
/// Each snapshot `id` is stored as two files inside `base_path`:
/// `<id>.snapshot.gz` (compressed payload) and `<id>.metadata.json` (metadata record).
#[derive(Debug, Clone)]
pub struct LocalStorage {
    /// Directory holding all snapshot and metadata files.
    base_path: PathBuf,
}
impl LocalStorage {
    /// Creates a storage backend rooted at `base_path`.
    ///
    /// The directory is not touched here; it is created lazily when the storage
    /// first needs it (see [`Self::ensure_dir`]).
    pub fn new(base_path: PathBuf) -> Self {
        Self { base_path }
    }

    /// Path of the gzip-compressed payload for `id`: `<base_path>/<id>.snapshot.gz`.
    fn snapshot_path(&self, id: &str) -> PathBuf {
        self.base_path.join(format!("{}.snapshot.gz", id))
    }

    /// Path of the JSON metadata sidecar for `id`: `<base_path>/<id>.metadata.json`.
    fn metadata_path(&self, id: &str) -> PathBuf {
        self.base_path.join(format!("{}.metadata.json", id))
    }

    /// Gzip-compresses `data` at the default compression level.
    ///
    /// # Errors
    ///
    /// Returns a compression error if the encoder fails to write or finish the stream.
    fn compress(data: &[u8]) -> Result<Vec<u8>> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder
            .write_all(data)
            .map_err(|e| SnapshotError::compression(format!("Compression failed: {}", e)))?;
        encoder
            .finish()
            .map_err(|e| SnapshotError::compression(format!("Finish compression failed: {}", e)))
    }

    /// Decompresses a gzip stream such as one produced by [`Self::compress`].
    ///
    /// # Errors
    ///
    /// Returns a compression error if `data` is not a valid gzip stream.
    fn decompress(data: &[u8]) -> Result<Vec<u8>> {
        let mut decoder = GzDecoder::new(data);
        let mut decompressed = Vec::new();
        decoder
            .read_to_end(&mut decompressed)
            .map_err(|e| SnapshotError::compression(format!("Decompression failed: {}", e)))?;
        Ok(decompressed)
    }

    /// Returns the SHA-256 digest of `data` as a 64-character lowercase hex string.
    fn calculate_checksum(data: &[u8]) -> String {
        format!("{:x}", Sha256::digest(data))
    }

    /// Ensures `base_path` exists, creating it and any missing parent directories.
    ///
    /// `create_dir_all` already succeeds when the directory exists, so there is no
    /// separate `exists()` pre-check: that would be a blocking syscall inside an
    /// async fn and a check-then-act race with concurrent creators.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the directory cannot be created.
    async fn ensure_dir(&self) -> Result<()> {
        fs::create_dir_all(&self.base_path).await?;
        Ok(())
    }
}
#[async_trait]
impl SnapshotStorage for LocalStorage {
async fn save(&self, snapshot_data: &SnapshotData) -> Result<Snapshot> {
self.ensure_dir().await?;
let id = snapshot_data.id().to_string();
let snapshot_path = self.snapshot_path(&id);
let metadata_path = self.metadata_path(&id);
let config = bincode::config::standard();
let serialized = bincode::encode_to_vec(snapshot_data, config)
.map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
let checksum = Self::calculate_checksum(&serialized);
let compressed = Self::compress(&serialized)?;
let size_bytes = compressed.len() as u64;
fs::write(&snapshot_path, &compressed).await?;
let created_at = chrono::DateTime::parse_from_rfc3339(&snapshot_data.metadata.created_at)
.map_err(|e| SnapshotError::storage(format!("Invalid timestamp: {}", e)))?
.with_timezone(&chrono::Utc);
let snapshot = Snapshot {
id: id.clone(),
collection_name: snapshot_data.collection_name().to_string(),
created_at,
vectors_count: snapshot_data.vectors_count(),
checksum,
size_bytes,
};
let metadata_json = serde_json::to_string_pretty(&snapshot)?;
fs::write(&metadata_path, metadata_json).await?;
Ok(snapshot)
}
async fn load(&self, id: &str) -> Result<SnapshotData> {
let snapshot_path = self.snapshot_path(id);
let metadata_path = self.metadata_path(id);
if !snapshot_path.exists() {
return Err(SnapshotError::SnapshotNotFound(id.to_string()));
}
let metadata_json = fs::read_to_string(&metadata_path).await?;
let snapshot: Snapshot = serde_json::from_str(&metadata_json)?;
let compressed = fs::read(&snapshot_path).await?;
let decompressed = Self::decompress(&compressed)?;
let actual_checksum = Self::calculate_checksum(&decompressed);
if actual_checksum != snapshot.checksum {
return Err(SnapshotError::InvalidChecksum {
expected: snapshot.checksum,
actual: actual_checksum,
});
}
let config = bincode::config::standard();
let (snapshot_data, _): (SnapshotData, usize) =
bincode::decode_from_slice(&decompressed, config)
.map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
Ok(snapshot_data)
}
async fn list(&self) -> Result<Vec<Snapshot>> {
self.ensure_dir().await?;
let mut snapshots = Vec::new();
let mut entries = fs::read_dir(&self.base_path).await?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if let Some(extension) = path.extension() {
if extension == "json" {
if let Some(file_name) = path.file_stem() {
let file_name_str = file_name.to_string_lossy();
if file_name_str.ends_with(".metadata") {
let contents = fs::read_to_string(&path).await?;
if let Ok(snapshot) = serde_json::from_str::<Snapshot>(&contents) {
snapshots.push(snapshot);
}
}
}
}
}
}
snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(snapshots)
}
async fn delete(&self, id: &str) -> Result<()> {
let snapshot_path = self.snapshot_path(id);
let metadata_path = self.metadata_path(id);
if !snapshot_path.exists() {
return Err(SnapshotError::SnapshotNotFound(id.to_string()));
}
fs::remove_file(&snapshot_path).await?;
if metadata_path.exists() {
fs::remove_file(&metadata_path).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::{CollectionConfig, DistanceMetric, VectorRecord};

    /// Temporary directory that is removed on drop, so a failing test does not
    /// leave files behind for the next run.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        /// Returns a path unique to this process and call, so concurrent test runs
        /// never share storage state.
        fn new(prefix: &str) -> Self {
            use std::sync::atomic::{AtomicUsize, Ordering};
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let unique = format!(
                "{}-{}-{}",
                prefix,
                std::process::id(),
                COUNTER.fetch_add(1, Ordering::SeqCst)
            );
            Self(std::env::temp_dir().join(unique))
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_compression_roundtrip() {
        let data = b"Hello, World! This is test data for compression.";
        let compressed = LocalStorage::compress(data).unwrap();
        let decompressed = LocalStorage::decompress(&compressed).unwrap();
        assert_eq!(data.to_vec(), decompressed);

        let empty = LocalStorage::compress(b"").unwrap();
        assert!(LocalStorage::decompress(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_checksum_calculation() {
        let checksum = LocalStorage::calculate_checksum(b"test data");
        assert_eq!(checksum.len(), 64);
        // FIPS 180-2 SHA-256 test vectors pin the exact lowercase-hex encoding.
        assert_eq!(
            LocalStorage::calculate_checksum(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        assert_eq!(
            LocalStorage::calculate_checksum(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[tokio::test]
    async fn test_local_storage_roundtrip() {
        let temp_dir = TempDir::new("ruvector-snapshot-test");
        let storage = LocalStorage::new(temp_dir.0.clone());
        let config = CollectionConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            hnsw_config: None,
        };
        let vectors = vec![
            VectorRecord::new("v1".to_string(), vec![1.0, 0.0, 0.0], None),
            VectorRecord::new("v2".to_string(), vec![0.0, 1.0, 0.0], None),
        ];
        let snapshot_data = SnapshotData::new("test-collection".to_string(), config, vectors);
        let id = snapshot_data.id().to_string();

        let snapshot = storage.save(&snapshot_data).await.unwrap();
        assert_eq!(snapshot.id, id);
        assert_eq!(snapshot.vectors_count, 2);

        let snapshots = storage.list().await.unwrap();
        assert!(snapshots.iter().any(|s| s.id == id));

        let loaded = storage.load(&id).await.unwrap();
        assert_eq!(loaded.id(), id);
        assert_eq!(loaded.vectors_count(), 2);

        storage.delete(&id).await.unwrap();
        assert!(matches!(
            storage.load(&id).await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
        assert!(storage.list().await.unwrap().iter().all(|s| s.id != id));
    }

    #[tokio::test]
    async fn test_missing_snapshot_is_not_found() {
        let temp_dir = TempDir::new("ruvector-snapshot-missing");
        let storage = LocalStorage::new(temp_dir.0.clone());
        assert!(matches!(
            storage.load("does-not-exist").await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
        assert!(matches!(
            storage.delete("does-not-exist").await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
    }
}