use common::{DakeraError, NamespaceId, Result, Vector};
use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::traits::VectorStorage;
// Process-wide monotonic counter appended to snapshot ids so that two
// snapshots created within the same millisecond still get distinct names.
static SNAPSHOT_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Configuration for [`SnapshotManager`]: where snapshots live, how many
/// to retain, and serialization options.
#[derive(Debug, Clone)]
pub struct SnapshotConfig {
    /// Directory where `.snap` payloads and `.meta` sidecars are written.
    pub snapshot_dir: PathBuf,
    /// Maximum number of snapshots retained before cleanup prunes old ones.
    pub max_snapshots: usize,
    /// Serialization mode flag — NOTE(review): only switches compact vs
    /// pretty JSON in `save_snapshot`; no actual compression is applied.
    pub compression_enabled: bool,
    /// Whether metadata should be included — NOTE(review): not read
    /// anywhere in this file; verify intended use elsewhere.
    pub include_metadata: bool,
}
impl Default for SnapshotConfig {
fn default() -> Self {
Self {
snapshot_dir: PathBuf::from("./data/snapshots"),
max_snapshots: 10,
compression_enabled: true,
include_metadata: true,
}
}
}
/// Descriptive record persisted alongside each snapshot in a `.meta` sidecar.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Unique snapshot id of the form `snap_<epoch_millis>_<counter>`.
    pub id: String,
    /// Creation time, seconds since the UNIX epoch.
    pub created_at: u64,
    /// Optional human-readable description supplied by the caller.
    pub description: Option<String>,
    /// Names of the namespaces captured by this snapshot.
    pub namespaces: Vec<String>,
    /// Total vector count across all captured namespaces.
    pub total_vectors: u64,
    /// On-disk size of the serialized `.snap` payload, in bytes.
    pub size_bytes: u64,
    /// Whether this snapshot is full or incremental.
    pub snapshot_type: SnapshotType,
    /// Id of the parent snapshot; set only for incremental snapshots.
    pub parent_id: Option<String>,
    /// Snapshot format version string.
    pub version: String,
}
/// Discriminates how much data a snapshot contains.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SnapshotType {
    /// Complete capture of all namespaces.
    Full,
    /// Only the namespaces changed since the parent snapshot.
    Incremental,
}
/// Serialized contents of a single namespace within a snapshot payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamespaceSnapshot {
    /// Namespace name.
    pub namespace: String,
    /// Number of vectors captured (always equals `vectors.len()`).
    pub vector_count: usize,
    /// Vector dimensionality, if the storage backend reports one.
    pub dimension: Option<usize>,
    /// The serialized vectors themselves.
    pub vectors: Vec<SerializedVector>,
}
/// Disk representation of a vector. TTL fields (`ttl_seconds`/`expires_at`)
/// are intentionally not persisted — see the `From` conversions below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedVector {
    /// Vector id.
    pub id: String,
    /// Raw component values.
    pub values: Vec<f32>,
    /// Optional user metadata as arbitrary JSON.
    pub metadata: Option<serde_json::Value>,
}
impl From<&Vector> for SerializedVector {
fn from(v: &Vector) -> Self {
Self {
id: v.id.clone(),
values: v.values.clone(),
metadata: v.metadata.clone(),
}
}
}
impl From<SerializedVector> for Vector {
fn from(sv: SerializedVector) -> Self {
Vector {
id: sv.id,
values: sv.values,
metadata: sv.metadata,
ttl_seconds: None,
expires_at: None,
}
}
}
/// Full payload written to a `.snap` file: metadata plus per-namespace data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotData {
    /// Embedded copy of the snapshot metadata. Note: this copy is written
    /// before the file size is known, so its `size_bytes` is 0; the `.meta`
    /// sidecar carries the real size.
    pub metadata: SnapshotMetadata,
    /// One entry per captured namespace.
    pub namespaces: Vec<NamespaceSnapshot>,
}
/// Creates, restores, lists, and prunes on-disk snapshots of a
/// `VectorStorage` backend.
pub struct SnapshotManager {
    // Immutable after construction; the snapshot directory is created
    // eagerly in `new`.
    config: SnapshotConfig,
}
impl SnapshotManager {
pub fn new(config: SnapshotConfig) -> Result<Self> {
fs::create_dir_all(&config.snapshot_dir)
.map_err(|e| DakeraError::Storage(format!("Failed to create snapshot dir: {}", e)))?;
Ok(Self { config })
}
pub async fn create_snapshot<S: VectorStorage>(
&self,
storage: &S,
description: Option<String>,
) -> Result<SnapshotMetadata> {
let snapshot_id = self.generate_snapshot_id();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock is before UNIX epoch")
.as_secs();
let namespaces = storage.list_namespaces().await?;
let mut namespace_snapshots = Vec::new();
let mut total_vectors = 0u64;
for namespace in &namespaces {
let vectors = storage.get_all(namespace).await?;
let dimension = storage.dimension(namespace).await?;
total_vectors += vectors.len() as u64;
let serialized: Vec<SerializedVector> =
vectors.iter().map(SerializedVector::from).collect();
namespace_snapshots.push(NamespaceSnapshot {
namespace: namespace.clone(),
vector_count: serialized.len(),
dimension,
vectors: serialized,
});
}
let metadata = SnapshotMetadata {
id: snapshot_id.clone(),
created_at,
description,
namespaces: namespaces.clone(),
total_vectors,
size_bytes: 0, snapshot_type: SnapshotType::Full,
parent_id: None,
version: "1.0.0".to_string(),
};
let snapshot_data = SnapshotData {
metadata: metadata.clone(),
namespaces: namespace_snapshots,
};
let size_bytes = self.save_snapshot(&snapshot_id, &snapshot_data)?;
let mut final_metadata = metadata;
final_metadata.size_bytes = size_bytes;
self.save_metadata(&snapshot_id, &final_metadata)?;
self.cleanup_old_snapshots()?;
Ok(final_metadata)
}
pub async fn create_incremental_snapshot<S: VectorStorage>(
&self,
storage: &S,
parent_id: &str,
changed_namespaces: &[NamespaceId],
description: Option<String>,
) -> Result<SnapshotMetadata> {
if !self.snapshot_exists(parent_id) {
return Err(DakeraError::Storage(format!(
"Parent snapshot not found: {}",
parent_id
)));
}
let snapshot_id = self.generate_snapshot_id();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock is before UNIX epoch")
.as_secs();
let mut namespace_snapshots = Vec::new();
let mut total_vectors = 0u64;
for namespace in changed_namespaces {
let vectors = storage.get_all(namespace).await?;
let dimension = storage.dimension(namespace).await?;
total_vectors += vectors.len() as u64;
let serialized: Vec<SerializedVector> =
vectors.iter().map(SerializedVector::from).collect();
namespace_snapshots.push(NamespaceSnapshot {
namespace: namespace.clone(),
vector_count: serialized.len(),
dimension,
vectors: serialized,
});
}
let metadata = SnapshotMetadata {
id: snapshot_id.clone(),
created_at,
description,
namespaces: changed_namespaces.to_vec(),
total_vectors,
size_bytes: 0,
snapshot_type: SnapshotType::Incremental,
parent_id: Some(parent_id.to_string()),
version: "1.0.0".to_string(),
};
let snapshot_data = SnapshotData {
metadata: metadata.clone(),
namespaces: namespace_snapshots,
};
let size_bytes = self.save_snapshot(&snapshot_id, &snapshot_data)?;
let mut final_metadata = metadata;
final_metadata.size_bytes = size_bytes;
self.save_metadata(&snapshot_id, &final_metadata)?;
self.cleanup_old_snapshots()?;
Ok(final_metadata)
}
pub async fn restore_snapshot<S: VectorStorage>(
&self,
storage: &S,
snapshot_id: &str,
) -> Result<RestoreResult> {
let snapshot_data = self.load_snapshot(snapshot_id)?;
let mut namespaces_restored = 0;
let mut vectors_restored = 0u64;
if snapshot_data.metadata.snapshot_type == SnapshotType::Incremental {
if let Some(parent_id) = &snapshot_data.metadata.parent_id {
let parent_result = Box::pin(self.restore_snapshot(storage, parent_id)).await?;
namespaces_restored += parent_result.namespaces_restored;
vectors_restored += parent_result.vectors_restored;
}
}
for ns_snapshot in &snapshot_data.namespaces {
storage.ensure_namespace(&ns_snapshot.namespace).await?;
let vectors: Vec<Vector> = ns_snapshot
.vectors
.iter()
.cloned()
.map(Vector::from)
.collect();
storage.upsert(&ns_snapshot.namespace, vectors).await?;
namespaces_restored += 1;
vectors_restored += ns_snapshot.vector_count as u64;
}
Ok(RestoreResult {
snapshot_id: snapshot_id.to_string(),
namespaces_restored,
vectors_restored,
})
}
pub fn list_snapshots(&self) -> Result<Vec<SnapshotMetadata>> {
let mut snapshots = Vec::new();
if let Ok(entries) = fs::read_dir(&self.config.snapshot_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|e| e == "meta").unwrap_or(false) {
if let Ok(metadata) = self.load_metadata_from_path(&path) {
snapshots.push(metadata);
}
}
}
}
snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(snapshots)
}
pub fn get_snapshot_metadata(&self, snapshot_id: &str) -> Result<SnapshotMetadata> {
let meta_path = self.metadata_path(snapshot_id);
self.load_metadata_from_path(&meta_path)
}
pub fn delete_snapshot(&self, snapshot_id: &str) -> Result<bool> {
let snapshot_path = self.snapshot_path(snapshot_id);
let meta_path = self.metadata_path(snapshot_id);
let mut deleted = false;
if snapshot_path.exists() {
if let Err(e) = fs::remove_file(&snapshot_path) {
tracing::warn!(
path = %snapshot_path.display(),
error = %e,
"Failed to remove snapshot file"
);
} else {
deleted = true;
}
}
if meta_path.exists() {
if let Err(e) = fs::remove_file(&meta_path) {
tracing::warn!(
path = %meta_path.display(),
error = %e,
"Failed to remove snapshot metadata file"
);
} else {
deleted = true;
}
}
Ok(deleted)
}
pub fn snapshot_exists(&self, snapshot_id: &str) -> bool {
self.snapshot_path(snapshot_id).exists()
}
fn generate_snapshot_id(&self) -> String {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock is before UNIX epoch")
.as_millis();
let counter = SNAPSHOT_COUNTER.fetch_add(1, Ordering::Relaxed);
format!("snap_{}_{}", timestamp, counter)
}
fn snapshot_path(&self, snapshot_id: &str) -> PathBuf {
self.config
.snapshot_dir
.join(format!("{}.snap", snapshot_id))
}
fn metadata_path(&self, snapshot_id: &str) -> PathBuf {
self.config
.snapshot_dir
.join(format!("{}.meta", snapshot_id))
}
fn save_snapshot(&self, snapshot_id: &str, data: &SnapshotData) -> Result<u64> {
let path = self.snapshot_path(snapshot_id);
let file = File::create(&path)
.map_err(|e| DakeraError::Storage(format!("Failed to create snapshot: {}", e)))?;
let writer = BufWriter::new(file);
if self.config.compression_enabled {
serde_json::to_writer(writer, data)
.map_err(|e| DakeraError::Storage(format!("Snapshot serialize error: {}", e)))?;
} else {
serde_json::to_writer_pretty(writer, data)
.map_err(|e| DakeraError::Storage(format!("Snapshot serialize error: {}", e)))?;
}
let metadata = fs::metadata(&path)
.map_err(|e| DakeraError::Storage(format!("Failed to get snapshot size: {}", e)))?;
Ok(metadata.len())
}
fn load_snapshot(&self, snapshot_id: &str) -> Result<SnapshotData> {
let path = self.snapshot_path(snapshot_id);
let file = File::open(&path)
.map_err(|e| DakeraError::Storage(format!("Failed to open snapshot: {}", e)))?;
let reader = BufReader::new(file);
serde_json::from_reader(reader)
.map_err(|e| DakeraError::Storage(format!("Snapshot deserialize error: {}", e)))
}
fn save_metadata(&self, snapshot_id: &str, metadata: &SnapshotMetadata) -> Result<()> {
let path = self.metadata_path(snapshot_id);
let file = File::create(&path)
.map_err(|e| DakeraError::Storage(format!("Failed to create metadata: {}", e)))?;
let writer = BufWriter::new(file);
serde_json::to_writer_pretty(writer, metadata)
.map_err(|e| DakeraError::Storage(format!("Metadata serialize error: {}", e)))?;
Ok(())
}
fn load_metadata_from_path(&self, path: &Path) -> Result<SnapshotMetadata> {
let file = File::open(path)
.map_err(|e| DakeraError::Storage(format!("Failed to open metadata: {}", e)))?;
let reader = BufReader::new(file);
serde_json::from_reader(reader)
.map_err(|e| DakeraError::Storage(format!("Metadata deserialize error: {}", e)))
}
fn cleanup_old_snapshots(&self) -> Result<()> {
let mut snapshots = self.list_snapshots()?;
if snapshots.len() > self.config.max_snapshots {
let to_remove = snapshots.split_off(self.config.max_snapshots);
let mut deleted_ids = std::collections::HashSet::new();
for snapshot in &to_remove {
let is_parent_of_kept = snapshots
.iter()
.any(|s| s.parent_id.as_ref() == Some(&snapshot.id));
let is_parent_of_remaining = to_remove.iter().any(|s| {
s.parent_id.as_ref() == Some(&snapshot.id) && !deleted_ids.contains(&s.id)
});
if !is_parent_of_kept && !is_parent_of_remaining {
self.delete_snapshot(&snapshot.id)?;
deleted_ids.insert(snapshot.id.clone());
}
}
}
Ok(())
}
}
/// Summary counters returned by `SnapshotManager::restore_snapshot`.
#[derive(Debug, Clone)]
pub struct RestoreResult {
    /// Id of the snapshot that was requested (top of any incremental chain).
    pub snapshot_id: String,
    /// Number of namespace restorations performed; for incremental chains
    /// each restored layer is counted separately.
    pub namespaces_restored: usize,
    /// Number of vectors upserted across all restored layers.
    pub vectors_restored: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::InMemoryStorage;
    use tempfile::TempDir;

    /// Config pointing at `dir`, retaining up to five snapshots and using
    /// pretty (non-compact) JSON so failing snapshots are easy to inspect.
    fn test_config(dir: &Path) -> SnapshotConfig {
        SnapshotConfig {
            max_snapshots: 5,
            compression_enabled: false,
            include_metadata: true,
            snapshot_dir: dir.to_path_buf(),
        }
    }

    /// A vector of `dim` ones with the given id, no metadata, and no TTL.
    fn create_test_vector(id: &str, dim: usize) -> Vector {
        Vector {
            id: String::from(id),
            values: std::iter::repeat(1.0).take(dim).collect(),
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn test_create_snapshot() {
        let tmp = TempDir::new().unwrap();
        let mgr = SnapshotManager::new(test_config(tmp.path())).unwrap();
        let store = InMemoryStorage::new();
        let ns = "test".to_string();
        store.ensure_namespace(&ns).await.unwrap();
        let batch = vec![create_test_vector("v1", 4), create_test_vector("v2", 4)];
        store.upsert(&ns, batch).await.unwrap();

        let meta = mgr
            .create_snapshot(&store, Some("Test snapshot".to_string()))
            .await
            .unwrap();

        assert_eq!(meta.total_vectors, 2);
        assert_eq!(meta.namespaces.len(), 1);
        assert_eq!(meta.snapshot_type, SnapshotType::Full);
    }

    #[tokio::test]
    async fn test_restore_snapshot() {
        let tmp = TempDir::new().unwrap();
        let mgr = SnapshotManager::new(test_config(tmp.path())).unwrap();
        let store = InMemoryStorage::new();
        let ns = "test".to_string();
        store.ensure_namespace(&ns).await.unwrap();
        store
            .upsert(&ns, vec![create_test_vector("v1", 4)])
            .await
            .unwrap();
        let meta = mgr.create_snapshot(&store, None).await.unwrap();

        // Wipe the data, then bring it back from the snapshot.
        store.delete(&ns, &["v1".to_string()]).await.unwrap();
        assert_eq!(store.count(&ns).await.unwrap(), 0);

        let outcome = mgr.restore_snapshot(&store, &meta.id).await.unwrap();
        assert_eq!(outcome.vectors_restored, 1);
        assert_eq!(store.count(&ns).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_list_snapshots() {
        let tmp = TempDir::new().unwrap();
        let mgr = SnapshotManager::new(test_config(tmp.path())).unwrap();
        let store = InMemoryStorage::new();
        let ns = "test".to_string();
        store.ensure_namespace(&ns).await.unwrap();
        store
            .upsert(&ns, vec![create_test_vector("v1", 4)])
            .await
            .unwrap();

        for i in 0..3 {
            mgr.create_snapshot(&store, Some(format!("Snapshot {}", i)))
                .await
                .unwrap();
        }

        let listed = mgr.list_snapshots().unwrap();
        assert_eq!(listed.len(), 3);
        // list_snapshots sorts newest-first.
        assert!(listed[0].created_at >= listed[1].created_at);
    }

    #[tokio::test]
    async fn test_delete_snapshot() {
        let tmp = TempDir::new().unwrap();
        let mgr = SnapshotManager::new(test_config(tmp.path())).unwrap();
        let store = InMemoryStorage::new();
        let ns = "test".to_string();
        store.ensure_namespace(&ns).await.unwrap();
        store
            .upsert(&ns, vec![create_test_vector("v1", 4)])
            .await
            .unwrap();

        let meta = mgr.create_snapshot(&store, None).await.unwrap();
        assert!(mgr.snapshot_exists(&meta.id));
        mgr.delete_snapshot(&meta.id).unwrap();
        assert!(!mgr.snapshot_exists(&meta.id));
    }

    #[tokio::test]
    async fn test_snapshot_cleanup() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path());
        config.max_snapshots = 3;
        let mgr = SnapshotManager::new(config).unwrap();
        let store = InMemoryStorage::new();
        let ns = "test".to_string();
        store.ensure_namespace(&ns).await.unwrap();
        store
            .upsert(&ns, vec![create_test_vector("v1", 4)])
            .await
            .unwrap();

        for _ in 0..5 {
            mgr.create_snapshot(&store, None).await.unwrap();
            // Space creations out so created_at ordering is deterministic.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }

        assert!(mgr.list_snapshots().unwrap().len() <= 3);
    }
}