use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::hnsw::{HnswConfig, HnswIndex};
use crate::ivf::{IndexedVector, IvfConfig, IvfIndex};
use crate::pq::ProductQuantizer;
use crate::spfresh::{Cluster, SpFreshConfig, SpFreshIndex};
use common::{Vector, VectorId};
use std::collections::{HashMap, HashSet};
pub use storage::IndexType;
/// Serialization contract for index types that can be persisted as bytes.
///
/// Implementors define a serde-compatible `Snapshot` representation plus
/// conversions to and from it; the provided `to_bytes`/`from_bytes` methods
/// then supply JSON round-tripping for free.
pub trait Persistable: Sized {
    /// Serde-friendly representation of the index state.
    type Snapshot: Serialize + DeserializeOwned;
    /// Capture the current state as a snapshot value.
    fn to_snapshot(&self) -> Self::Snapshot;
    /// Rebuild the index from a previously captured snapshot.
    fn from_snapshot(snapshot: Self::Snapshot) -> Result<Self, String>;
    /// Serialize the current snapshot to JSON bytes.
    fn to_bytes(&self) -> Result<Vec<u8>, String> {
        serde_json::to_vec(&self.to_snapshot())
            .map_err(|e| format!("Failed to serialize index: {}", e))
    }
    /// Parse JSON bytes into a snapshot, then rebuild the index from it.
    fn from_bytes(data: &[u8]) -> Result<Self, String> {
        let parsed: Self::Snapshot = serde_json::from_slice(data)
            .map_err(|e| format!("Failed to deserialize index: {}", e))?;
        Self::from_snapshot(parsed)
    }
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PQSnapshot {
pub quantizer: ProductQuantizer,
}
impl Persistable for ProductQuantizer {
    type Snapshot = PQSnapshot;

    /// Wrap a clone of the quantizer in its snapshot form.
    fn to_snapshot(&self) -> PQSnapshot {
        let quantizer = self.clone();
        PQSnapshot { quantizer }
    }

    /// Unwrap the quantizer from the snapshot; this cannot fail.
    fn from_snapshot(snapshot: PQSnapshot) -> Result<Self, String> {
        let PQSnapshot { quantizer } = snapshot;
        Ok(quantizer)
    }
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfTrainingSnapshot {
pub config: IvfConfig,
pub dimension: usize,
pub centroids: Vec<Vec<f32>>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SpFreshTrainingSnapshot {
pub config: SpFreshConfig,
pub dimension: usize,
pub centroids: Vec<Vec<f32>>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HnswConfigSnapshot {
pub config: HnswConfig,
pub dimension: usize,
}
/// Plain-data form of one HNSW graph node, suitable for serde serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableHnswNode {
    // External vector identifier.
    pub id: String,
    // The stored embedding values.
    pub vector: Vec<f32>,
    // Per-layer neighbor lists; outer index is the layer, inner values are
    // node indices into the snapshot's `nodes` vector.
    pub connections: Vec<Vec<usize>>,
    // Highest layer this node participates in.
    pub max_layer: usize,
}
/// Complete serializable state of an HNSW index: config, all nodes, and the
/// graph entry metadata needed to resume searching after a reload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswFullSnapshot {
    // Build parameters the graph was constructed with.
    pub config: HnswConfig,
    // Vector dimensionality; 0 is used as a placeholder for an empty index
    // (see `to_snapshot`, which applies `unwrap_or(0)`).
    pub dimension: usize,
    // Every node in the graph, in index order.
    pub nodes: Vec<SerializableHnswNode>,
    // Index of the search entry node, if any vectors have been inserted.
    pub entry_point: Option<usize>,
    // Highest populated layer of the graph.
    pub max_level: usize,
}
impl Persistable for HnswIndex {
    type Snapshot = HnswFullSnapshot;

    /// Capture the full graph (nodes, entry point, level, config) into a
    /// serde-friendly snapshot.
    fn to_snapshot(&self) -> HnswFullSnapshot {
        let mut nodes = Vec::new();
        for node in self.nodes_read() {
            let serializable = SerializableHnswNode {
                id: node.id,
                vector: node.vector,
                connections: node.connections,
                max_layer: node.max_layer,
            };
            nodes.push(serializable);
        }
        HnswFullSnapshot {
            config: self.config().clone(),
            // An empty index has no dimension yet; persist 0 as a placeholder.
            dimension: self.dimension().unwrap_or(0),
            nodes,
            entry_point: self.entry_point(),
            max_level: self.max_level(),
        }
    }

    /// Delegate to the index's own inherent snapshot constructor.
    fn from_snapshot(snapshot: HnswFullSnapshot) -> Result<Self, String> {
        HnswIndex::from_snapshot(snapshot)
    }
}
/// Plain-data form of a vector stored in an inverted list.
///
/// NOTE(review): this type appears unused within this file — `IvfFullSnapshot`
/// serializes `IndexedVector` directly. It may be kept for wire-format
/// compatibility or external callers; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableIndexedVector {
    // External vector identifier.
    pub id: String,
    // The stored embedding values.
    pub values: Vec<f32>,
}
/// Complete serializable state of an IVF index, including the inverted lists,
/// so a reload needs no retraining.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfFullSnapshot {
    // Clustering / probing parameters.
    pub config: IvfConfig,
    // Vector dimensionality; `None` until the first train/add.
    pub dimension: Option<usize>,
    // Learned cluster centroids.
    pub centroids: Vec<Vec<f32>>,
    // Cluster id -> vectors assigned to that cluster.
    pub inverted_lists: HashMap<usize, Vec<IndexedVector>>,
    // Total number of stored vectors across all lists.
    pub vector_count: usize,
    // Whether `train` has been run (centroids are valid).
    pub is_trained: bool,
}
impl Persistable for IvfIndex {
    type Snapshot = IvfFullSnapshot;

    /// Snapshot the full IVF state: centroids, inverted lists, counters, flags.
    fn to_snapshot(&self) -> IvfFullSnapshot {
        let config = self.config().clone();
        let centroids = self.centroids_read();
        let inverted_lists = self.inverted_lists_read();
        IvfFullSnapshot {
            config,
            dimension: self.dimension(),
            centroids,
            inverted_lists,
            vector_count: self.len(),
            is_trained: self.is_trained(),
        }
    }

    /// Delegate to the index's own inherent snapshot constructor.
    fn from_snapshot(snapshot: IvfFullSnapshot) -> Result<Self, String> {
        IvfIndex::from_snapshot(snapshot)
    }
}
/// Complete serializable state of a SpFresh index, including tombstones and
/// pending (not-yet-clustered) vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpFreshFullSnapshot {
    // Clustering / probing parameters.
    pub config: SpFreshConfig,
    // All clusters with their member vectors.
    pub clusters: Vec<Cluster>,
    // Vector id -> index of the cluster it lives in.
    pub vector_cluster_map: HashMap<VectorId, usize>,
    // Ids that were removed but not yet compacted out of the clusters.
    pub global_tombstones: HashSet<VectorId>,
    // Vectors inserted before training that await cluster assignment.
    pub pending_vectors: Vec<Vector>,
    // Whether `train` has been run.
    pub trained: bool,
    // Vector dimensionality; `None` until the first train/add.
    pub dimension: Option<usize>,
}
impl Persistable for SpFreshIndex {
    type Snapshot = SpFreshFullSnapshot;

    /// Snapshot the full SpFresh state: clusters, mappings, tombstones, and
    /// any vectors still pending assignment.
    fn to_snapshot(&self) -> SpFreshFullSnapshot {
        let clusters = self.clusters_read();
        let vector_cluster_map = self.vector_cluster_map_read();
        let global_tombstones = self.global_tombstones_read();
        let pending_vectors = self.pending_vectors_read();
        SpFreshFullSnapshot {
            config: self.config().clone(),
            clusters,
            vector_cluster_map,
            global_tombstones,
            pending_vectors,
            trained: self.is_trained(),
            dimension: self.dimension(),
        }
    }

    /// Delegate to the index's own inherent snapshot constructor.
    fn from_snapshot(snapshot: SpFreshFullSnapshot) -> Result<Self, String> {
        SpFreshIndex::from_snapshot(snapshot)
    }
}
/// Coordinates saving and loading of every index type through a pluggable
/// storage backend `S` (bounded by `storage::IndexStorage` on the impl that
/// needs it, not on the struct itself).
pub struct IndexPersistenceManager<S> {
    // The underlying storage backend.
    storage: S,
}
impl<S> IndexPersistenceManager<S> {
pub fn new(storage: S) -> Self {
Self { storage }
}
}
impl<S: storage::IndexStorage> IndexPersistenceManager<S> {
pub async fn save_hnsw(
&self,
namespace: &common::NamespaceId,
index: &HnswIndex,
) -> common::Result<()> {
let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
self.storage
.save_index(namespace, storage::IndexType::Hnsw, bytes)
.await
}
pub async fn load_hnsw(
&self,
namespace: &common::NamespaceId,
) -> common::Result<Option<HnswIndex>> {
match self
.storage
.load_index(namespace, storage::IndexType::Hnsw)
.await?
{
Some(bytes) => {
let index = HnswIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
Ok(Some(index))
}
None => Ok(None),
}
}
pub async fn save_pq(
&self,
namespace: &common::NamespaceId,
quantizer: &ProductQuantizer,
) -> common::Result<()> {
let bytes = quantizer.to_bytes().map_err(common::DakeraError::Storage)?;
self.storage
.save_index(namespace, storage::IndexType::Pq, bytes)
.await
}
pub async fn load_pq(
&self,
namespace: &common::NamespaceId,
) -> common::Result<Option<ProductQuantizer>> {
match self
.storage
.load_index(namespace, storage::IndexType::Pq)
.await?
{
Some(bytes) => {
let pq =
ProductQuantizer::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
Ok(Some(pq))
}
None => Ok(None),
}
}
pub async fn save_ivf(
&self,
namespace: &common::NamespaceId,
index: &IvfIndex,
) -> common::Result<()> {
let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
self.storage
.save_index(namespace, storage::IndexType::Ivf, bytes)
.await
}
pub async fn load_ivf(
&self,
namespace: &common::NamespaceId,
) -> common::Result<Option<IvfIndex>> {
match self
.storage
.load_index(namespace, storage::IndexType::Ivf)
.await?
{
Some(bytes) => {
let index = IvfIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
Ok(Some(index))
}
None => Ok(None),
}
}
pub async fn save_spfresh(
&self,
namespace: &common::NamespaceId,
index: &SpFreshIndex,
) -> common::Result<()> {
let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
self.storage
.save_index(namespace, storage::IndexType::SpFresh, bytes)
.await
}
pub async fn load_spfresh(
&self,
namespace: &common::NamespaceId,
) -> common::Result<Option<SpFreshIndex>> {
match self
.storage
.load_index(namespace, storage::IndexType::SpFresh)
.await?
{
Some(bytes) => {
let index =
SpFreshIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
Ok(Some(index))
}
None => Ok(None),
}
}
pub async fn index_exists(
&self,
namespace: &common::NamespaceId,
index_type: storage::IndexType,
) -> common::Result<bool> {
self.storage.index_exists(namespace, index_type).await
}
pub async fn delete_index(
&self,
namespace: &common::NamespaceId,
index_type: storage::IndexType,
) -> common::Result<bool> {
self.storage.delete_index(namespace, index_type).await
}
pub async fn list_indexes(
&self,
namespace: &common::NamespaceId,
) -> common::Result<Vec<storage::IndexType>> {
self.storage.list_indexes(namespace).await
}
}
/// Serialize any serde-serializable value to JSON bytes, mapping the error
/// into a human-readable `String`.
pub fn serialize_to_bytes<T: Serialize>(value: &T) -> Result<Vec<u8>, String> {
    match serde_json::to_vec(value) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(format!("Serialization failed: {}", e)),
    }
}
/// Deserialize a value of type `T` from JSON bytes, mapping the error into a
/// human-readable `String`.
pub fn deserialize_from_bytes<T: DeserializeOwned>(data: &[u8]) -> Result<T, String> {
    match serde_json::from_slice(data) {
        Ok(value) => Ok(value),
        Err(e) => Err(format!("Deserialization failed: {}", e)),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pq::PQConfig;
use common::DistanceMetric;
// Round-trips a trained product quantizer through to_bytes/from_bytes and
// checks the restored quantizer keeps its training state and shape.
#[test]
fn test_pq_quantizer_persistence() {
use common::Vector;
let config = PQConfig {
num_subquantizers: 4,
num_centroids: 16,
kmeans_iterations: 10,
distance_metric: DistanceMetric::Euclidean,
};
let mut pq = ProductQuantizer::new(config, 32).unwrap();
// Deterministic synthetic vectors: sin of a linear ramp per coordinate.
let vectors: Vec<Vector> = (0..100)
.map(|i| Vector {
id: format!("v{}", i),
values: (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect(),
metadata: None,
ttl_seconds: None,
expires_at: None,
})
.collect();
pq.train(&vectors).unwrap();
assert!(pq.is_trained());
let bytes = pq.to_bytes().unwrap();
assert!(!bytes.is_empty());
let restored = ProductQuantizer::from_bytes(&bytes).unwrap();
assert!(restored.is_trained());
assert_eq!(restored.dimension, 32);
assert_eq!(restored.config.num_subquantizers, 4);
assert_eq!(restored.codebooks.len(), 4);
}
// JSON round-trip of the training-only IVF snapshot via the free helpers.
#[test]
fn test_ivf_training_snapshot() {
let snapshot = IvfTrainingSnapshot {
config: IvfConfig {
n_clusters: 8,
n_probe: 2,
metric: DistanceMetric::Euclidean,
..Default::default()
},
dimension: 64,
centroids: vec![vec![0.0; 64]; 8],
};
let bytes = serialize_to_bytes(&snapshot).unwrap();
let restored: IvfTrainingSnapshot = deserialize_from_bytes(&bytes).unwrap();
assert_eq!(restored.config.n_clusters, 8);
assert_eq!(restored.dimension, 64);
assert_eq!(restored.centroids.len(), 8);
}
// JSON round-trip of the HNSW config-only snapshot.
// NOTE(review): asserts the default `m` is 16 — relies on HnswConfig::default.
#[test]
fn test_hnsw_config_snapshot() {
let snapshot = HnswConfigSnapshot {
config: HnswConfig::default(),
dimension: 128,
};
let bytes = serialize_to_bytes(&snapshot).unwrap();
let restored: HnswConfigSnapshot = deserialize_from_bytes(&bytes).unwrap();
assert_eq!(restored.config.m, 16);
assert_eq!(restored.dimension, 128);
}
// JSON round-trip of the SpFresh training-only snapshot.
#[test]
fn test_spfresh_training_snapshot() {
let snapshot = SpFreshTrainingSnapshot {
config: SpFreshConfig::default(),
dimension: 32,
centroids: vec![vec![1.0; 32]; 16],
};
let bytes = serialize_to_bytes(&snapshot).unwrap();
let restored: SpFreshTrainingSnapshot = deserialize_from_bytes(&bytes).unwrap();
assert_eq!(restored.dimension, 32);
assert_eq!(restored.centroids.len(), 16);
}
// Full graph round-trip: inserts 50 vectors, persists, reloads, and checks
// that search results (ids and scores) are identical on both copies.
#[test]
fn test_hnsw_full_persistence() {
use crate::hnsw::HnswIndex;
let index = HnswIndex::new();
for i in 0..50 {
let vector: Vec<f32> = (0..64).map(|j| ((i + j) as f32 * 0.1).sin()).collect();
index.insert(format!("vec_{}", i), vector);
}
assert_eq!(index.len(), 50);
let bytes = index.to_bytes().unwrap();
assert!(!bytes.is_empty());
let restored = HnswIndex::from_bytes(&bytes).unwrap();
assert_eq!(restored.len(), 50);
assert_eq!(restored.dimension(), index.dimension());
assert_eq!(restored.max_level(), index.max_level());
let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.1).sin()).collect();
let original_results = index.search(&query, 5);
let restored_results = restored.search(&query, 5);
assert_eq!(original_results.len(), restored_results.len());
// Same ids, and distances equal up to float round-trip tolerance.
for (orig, rest) in original_results.iter().zip(restored_results.iter()) {
assert_eq!(orig.0, rest.0); assert!((orig.1 - rest.1).abs() < 1e-6); }
}
// An empty HNSW index must round-trip to an empty index.
#[test]
fn test_hnsw_empty_persistence() {
use crate::hnsw::HnswIndex;
let index = HnswIndex::new();
let bytes = index.to_bytes().unwrap();
let restored = HnswIndex::from_bytes(&bytes).unwrap();
assert_eq!(restored.len(), 0);
assert!(restored.is_empty());
}
// Full IVF round-trip: train, add all training vectors, persist, reload,
// and check that a query for a known vector returns the same top hit.
#[test]
fn test_ivf_full_persistence() {
use crate::ivf::{IvfConfig, IvfIndex};
let training_vectors: Vec<Vec<f32>> = (0..100)
.map(|i| (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect())
.collect();
let mut index = IvfIndex::new(IvfConfig {
n_clusters: 10,
n_probe: 3,
..Default::default()
});
index.train(&training_vectors).unwrap();
assert!(index.is_trained());
for (i, v) in training_vectors.iter().enumerate() {
index.add(format!("vec_{}", i), v.clone()).unwrap();
}
assert_eq!(index.len(), 100);
let bytes = index.to_bytes().unwrap();
assert!(!bytes.is_empty());
let restored = IvfIndex::from_bytes(&bytes).unwrap();
assert_eq!(restored.len(), 100);
assert!(restored.is_trained());
assert_eq!(restored.n_clusters(), 10);
// Querying with an indexed vector should return itself first on both copies.
let query = &training_vectors[0];
let original_results = index.search(query, 5).unwrap();
let restored_results = restored.search(query, 5).unwrap();
assert_eq!(original_results[0].id, restored_results[0].id);
assert_eq!(original_results[0].id, "vec_0");
}
// An empty, untrained IVF index must round-trip to the same state.
#[test]
fn test_ivf_empty_persistence() {
use crate::ivf::{IvfConfig, IvfIndex};
let index = IvfIndex::new(IvfConfig::default());
let bytes = index.to_bytes().unwrap();
let restored = IvfIndex::from_bytes(&bytes).unwrap();
assert_eq!(restored.len(), 0);
assert!(!restored.is_trained());
}
// Full SpFresh round-trip. Search here is approximate, so instead of exact
// equality the test requires >= 80% overlap between the top-10 result sets
// of the original and the restored index.
#[test]
fn test_spfresh_full_persistence() {
use crate::spfresh::{SpFreshConfig, SpFreshIndex};
use common::Vector;
let training_vectors: Vec<Vector> = (0..100)
.map(|i| Vector {
id: format!("vec_{}", i),
values: (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect(),
metadata: None,
ttl_seconds: None,
expires_at: None,
})
.collect();
let index = SpFreshIndex::new(SpFreshConfig {
num_clusters: 4,
n_probe: 2,
..Default::default()
});
index.train(&training_vectors).unwrap();
assert!(index.is_trained());
let stats = index.stats();
assert_eq!(stats.total_vectors, 100);
assert_eq!(stats.num_clusters, 4);
let bytes = index.to_bytes().unwrap();
assert!(!bytes.is_empty());
let restored = SpFreshIndex::from_bytes(&bytes).unwrap();
let restored_stats = restored.stats();
assert!(restored.is_trained());
assert_eq!(restored_stats.total_vectors, 100);
assert_eq!(restored_stats.num_clusters, 4);
assert_eq!(restored_stats.dimension, Some(32));
let query = &training_vectors[50].values;
let original_results = index.search(query, 10).unwrap();
let restored_results = restored.search(query, 10).unwrap();
assert_eq!(original_results.len(), restored_results.len());
let original_ids: std::collections::HashSet<_> =
original_results.iter().map(|r| &r.id).collect();
let restored_ids: std::collections::HashSet<_> =
restored_results.iter().map(|r| &r.id).collect();
let overlap = original_ids.intersection(&restored_ids).count();
assert!(
overlap >= 8,
"Expected at least 80% overlap in top-10 results, got {}/10",
overlap
);
}
// An empty, untrained SpFresh index must round-trip to the same state.
#[test]
fn test_spfresh_empty_persistence() {
use crate::spfresh::{SpFreshConfig, SpFreshIndex};
let index = SpFreshIndex::new(SpFreshConfig::default());
let bytes = index.to_bytes().unwrap();
let restored = SpFreshIndex::from_bytes(&bytes).unwrap();
let stats = restored.stats();
assert_eq!(stats.total_vectors, 0);
assert!(!restored.is_trained());
}
// Verifies tombstones survive persistence: removed ids must stay removed
// after a reload and never reappear in search results.
#[test]
fn test_spfresh_persistence_with_tombstones() {
use crate::spfresh::{SpFreshConfig, SpFreshIndex};
use common::Vector;
let training_vectors: Vec<Vector> = (0..50)
.map(|i| Vector {
id: format!("vec_{}", i),
values: (0..16).map(|j| ((i + j) as f32 * 0.1).cos()).collect(),
metadata: None,
ttl_seconds: None,
expires_at: None,
})
.collect();
let index = SpFreshIndex::new(SpFreshConfig {
num_clusters: 2,
..Default::default()
});
index.train(&training_vectors).unwrap();
// Remove the first 10 ids, then persist with the tombstones in place.
let ids_to_remove: Vec<String> = (0..10).map(|i| format!("vec_{}", i)).collect();
let removed = index.remove(&ids_to_remove);
assert_eq!(removed, 10);
let stats = index.stats();
assert_eq!(stats.total_vectors, 40);
assert_eq!(stats.total_tombstones, 10);
let bytes = index.to_bytes().unwrap();
let restored = SpFreshIndex::from_bytes(&bytes).unwrap();
let restored_stats = restored.stats();
assert_eq!(restored_stats.total_vectors, 40);
assert_eq!(restored_stats.total_tombstones, 10);
// Ask for every remaining vector and confirm no tombstoned id leaks back.
let results = restored.search(&training_vectors[0].values, 50).unwrap();
for result in &results {
let id_num: usize = result.id.strip_prefix("vec_").unwrap().parse().unwrap();
assert!(
id_num >= 10,
"Tombstoned vector {} appeared in results",
result.id
);
}
}
}