1use serde::{de::DeserializeOwned, Deserialize, Serialize};
10
11use crate::hnsw::{HnswConfig, HnswIndex};
12use crate::ivf::{IndexedVector, IvfConfig, IvfIndex};
13use crate::pq::ProductQuantizer;
14use crate::spfresh::{Cluster, SpFreshConfig, SpFreshIndex};
15use common::{Vector, VectorId};
16use std::collections::{HashMap, HashSet};
17
18pub use storage::IndexType;
20
21pub trait Persistable: Sized {
23 type Snapshot: Serialize + DeserializeOwned;
25
26 fn to_snapshot(&self) -> Self::Snapshot;
28
29 fn from_snapshot(snapshot: Self::Snapshot) -> Result<Self, String>;
31
32 fn to_bytes(&self) -> Result<Vec<u8>, String> {
34 let snapshot = self.to_snapshot();
35 serde_json::to_vec(&snapshot).map_err(|e| format!("Failed to serialize index: {}", e))
36 }
37
38 fn from_bytes(data: &[u8]) -> Result<Self, String> {
40 let snapshot: Self::Snapshot = serde_json::from_slice(data)
41 .map_err(|e| format!("Failed to deserialize index: {}", e))?;
42 Self::from_snapshot(snapshot)
43 }
44}
45
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
48pub struct PQSnapshot {
49 pub quantizer: ProductQuantizer,
51}
52
53impl Persistable for ProductQuantizer {
54 type Snapshot = PQSnapshot;
55
56 fn to_snapshot(&self) -> PQSnapshot {
57 PQSnapshot {
58 quantizer: self.clone(),
59 }
60 }
61
62 fn from_snapshot(snapshot: PQSnapshot) -> Result<Self, String> {
63 Ok(snapshot.quantizer)
64 }
65}
66
67#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
69pub struct IvfTrainingSnapshot {
70 pub config: IvfConfig,
71 pub dimension: usize,
72 pub centroids: Vec<Vec<f32>>,
73}
74
75#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub struct SpFreshTrainingSnapshot {
78 pub config: SpFreshConfig,
79 pub dimension: usize,
80 pub centroids: Vec<Vec<f32>>,
81}
82
83#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
86pub struct HnswConfigSnapshot {
87 pub config: HnswConfig,
88 pub dimension: usize,
89}
90
/// One node of an HNSW graph in serializable form.
///
/// Built field-by-field from the nodes returned by `HnswIndex::nodes_read()` in
/// `HnswIndex::to_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableHnswNode {
    /// Caller-supplied vector id.
    pub id: String,
    /// The raw vector values.
    pub vector: Vec<f32>,
    /// Neighbour lists. NOTE(review): presumably the outer index is the layer and the
    /// inner values are positions in `HnswFullSnapshot::nodes` — confirm against crate::hnsw.
    pub connections: Vec<Vec<usize>>,
    /// Highest layer this node participates in.
    pub max_layer: usize,
}
103
/// Complete serialized state of an [`HnswIndex`], used as its `Persistable::Snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswFullSnapshot {
    /// Index configuration (a clone of `HnswIndex::config()`).
    pub config: HnswConfig,
    /// Vector dimension. `HnswIndex::to_snapshot` writes `0` when the index reports no
    /// dimension (e.g. an empty index), so `0` doubles as "unknown" here.
    pub dimension: usize,
    /// All graph nodes.
    pub nodes: Vec<SerializableHnswNode>,
    /// Search entry point (presumably a position in `nodes`; `None` for an empty graph).
    pub entry_point: Option<usize>,
    /// Highest layer present in the graph.
    pub max_level: usize,
}
119
120impl Persistable for HnswIndex {
121 type Snapshot = HnswFullSnapshot;
122
123 fn to_snapshot(&self) -> HnswFullSnapshot {
124 let node_snapshots = self.nodes_read();
125 let serializable_nodes: Vec<SerializableHnswNode> = node_snapshots
126 .into_iter()
127 .map(|node| SerializableHnswNode {
128 id: node.id,
129 vector: node.vector,
130 connections: node.connections,
131 max_layer: node.max_layer,
132 })
133 .collect();
134
135 HnswFullSnapshot {
136 config: self.config().clone(),
137 dimension: self.dimension().unwrap_or(0),
138 nodes: serializable_nodes,
139 entry_point: self.entry_point(),
140 max_level: self.max_level(),
141 }
142 }
143
144 fn from_snapshot(snapshot: HnswFullSnapshot) -> Result<Self, String> {
145 HnswIndex::from_snapshot(snapshot)
146 }
147}
148
/// A vector id paired with its values, in serializable form.
///
/// NOTE(review): no snapshot type in this file references this struct (`IvfFullSnapshot`
/// stores `crate::ivf::IndexedVector` directly). It may be kept for external callers or
/// older formats — confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableIndexedVector {
    /// Vector id.
    pub id: String,
    /// Raw vector values.
    pub values: Vec<f32>,
}
157
/// Complete serialized state of an [`IvfIndex`], used as its `Persistable::Snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfFullSnapshot {
    /// Index configuration (a clone of `IvfIndex::config()`).
    pub config: IvfConfig,
    /// Vector dimension as reported by `IvfIndex::dimension()`; presumably `None` until
    /// the index has seen data.
    pub dimension: Option<usize>,
    /// Trained centroids from `centroids_read()` (presumably one row per cluster).
    pub centroids: Vec<Vec<f32>>,
    /// Posting lists keyed by cluster number (presumably the centroid's position).
    pub inverted_lists: HashMap<usize, Vec<IndexedVector>>,
    /// Value of `IvfIndex::len()` when the snapshot was taken.
    pub vector_count: usize,
    /// Whether the index had been trained when the snapshot was taken.
    pub is_trained: bool,
}
175
176impl Persistable for IvfIndex {
177 type Snapshot = IvfFullSnapshot;
178
179 fn to_snapshot(&self) -> IvfFullSnapshot {
180 IvfFullSnapshot {
181 config: self.config().clone(),
182 dimension: self.dimension(),
183 centroids: self.centroids_read(),
184 inverted_lists: self.inverted_lists_read(),
185 vector_count: self.len(),
186 is_trained: self.is_trained(),
187 }
188 }
189
190 fn from_snapshot(snapshot: IvfFullSnapshot) -> Result<Self, String> {
191 IvfIndex::from_snapshot(snapshot)
192 }
193}
194
/// Complete serialized state of an [`SpFreshIndex`], used as its `Persistable::Snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpFreshFullSnapshot {
    /// Index configuration (a clone of `SpFreshIndex::config()`).
    pub config: SpFreshConfig,
    /// All clusters, as returned by `clusters_read()`.
    pub clusters: Vec<Cluster>,
    /// Maps each vector id to a cluster (presumably a position in `clusters`).
    pub vector_cluster_map: HashMap<VectorId, usize>,
    /// Ids of removed vectors. The tombstone round-trip test expects these ids to stay
    /// excluded from search results after restore.
    pub global_tombstones: HashSet<VectorId>,
    /// Vectors from `pending_vectors_read()` (presumably not yet assigned to a cluster).
    pub pending_vectors: Vec<Vector>,
    /// Whether the index had been trained when the snapshot was taken.
    pub trained: bool,
    /// Vector dimension; presumably `None` until the index has seen data.
    pub dimension: Option<usize>,
}
214
215impl Persistable for SpFreshIndex {
216 type Snapshot = SpFreshFullSnapshot;
217
218 fn to_snapshot(&self) -> SpFreshFullSnapshot {
219 SpFreshFullSnapshot {
220 config: self.config().clone(),
221 clusters: self.clusters_read(),
222 vector_cluster_map: self.vector_cluster_map_read(),
223 global_tombstones: self.global_tombstones_read(),
224 pending_vectors: self.pending_vectors_read(),
225 trained: self.is_trained(),
226 dimension: self.dimension(),
227 }
228 }
229
230 fn from_snapshot(snapshot: SpFreshFullSnapshot) -> Result<Self, String> {
231 SpFreshIndex::from_snapshot(snapshot)
232 }
233}
234
/// Saves and loads vector indexes through a storage backend `S`.
///
/// Indexes are converted to bytes with [`Persistable`] and handed to the backend. The
/// backend-facing methods are only available when `S: storage::IndexStorage`.
pub struct IndexPersistenceManager<S> {
    /// Backend that actually stores the serialized index bytes.
    storage: S,
}

impl<S> IndexPersistenceManager<S> {
    /// Wraps `storage` in a new manager. No I/O is performed.
    pub fn new(storage: S) -> Self {
        IndexPersistenceManager { storage }
    }
}
246
247impl<S: storage::IndexStorage> IndexPersistenceManager<S> {
248 pub async fn save_hnsw(
250 &self,
251 namespace: &common::NamespaceId,
252 index: &HnswIndex,
253 ) -> common::Result<()> {
254 let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
255 self.storage
256 .save_index(namespace, storage::IndexType::Hnsw, bytes)
257 .await
258 }
259
260 pub async fn load_hnsw(
262 &self,
263 namespace: &common::NamespaceId,
264 ) -> common::Result<Option<HnswIndex>> {
265 match self
266 .storage
267 .load_index(namespace, storage::IndexType::Hnsw)
268 .await?
269 {
270 Some(bytes) => {
271 let index = HnswIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
272 Ok(Some(index))
273 }
274 None => Ok(None),
275 }
276 }
277
278 pub async fn save_pq(
280 &self,
281 namespace: &common::NamespaceId,
282 quantizer: &ProductQuantizer,
283 ) -> common::Result<()> {
284 let bytes = quantizer.to_bytes().map_err(common::DakeraError::Storage)?;
285 self.storage
286 .save_index(namespace, storage::IndexType::Pq, bytes)
287 .await
288 }
289
290 pub async fn load_pq(
292 &self,
293 namespace: &common::NamespaceId,
294 ) -> common::Result<Option<ProductQuantizer>> {
295 match self
296 .storage
297 .load_index(namespace, storage::IndexType::Pq)
298 .await?
299 {
300 Some(bytes) => {
301 let pq =
302 ProductQuantizer::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
303 Ok(Some(pq))
304 }
305 None => Ok(None),
306 }
307 }
308
309 pub async fn save_ivf(
311 &self,
312 namespace: &common::NamespaceId,
313 index: &IvfIndex,
314 ) -> common::Result<()> {
315 let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
316 self.storage
317 .save_index(namespace, storage::IndexType::Ivf, bytes)
318 .await
319 }
320
321 pub async fn load_ivf(
323 &self,
324 namespace: &common::NamespaceId,
325 ) -> common::Result<Option<IvfIndex>> {
326 match self
327 .storage
328 .load_index(namespace, storage::IndexType::Ivf)
329 .await?
330 {
331 Some(bytes) => {
332 let index = IvfIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
333 Ok(Some(index))
334 }
335 None => Ok(None),
336 }
337 }
338
339 pub async fn save_spfresh(
341 &self,
342 namespace: &common::NamespaceId,
343 index: &SpFreshIndex,
344 ) -> common::Result<()> {
345 let bytes = index.to_bytes().map_err(common::DakeraError::Storage)?;
346 self.storage
347 .save_index(namespace, storage::IndexType::SpFresh, bytes)
348 .await
349 }
350
351 pub async fn load_spfresh(
353 &self,
354 namespace: &common::NamespaceId,
355 ) -> common::Result<Option<SpFreshIndex>> {
356 match self
357 .storage
358 .load_index(namespace, storage::IndexType::SpFresh)
359 .await?
360 {
361 Some(bytes) => {
362 let index =
363 SpFreshIndex::from_bytes(&bytes).map_err(common::DakeraError::Storage)?;
364 Ok(Some(index))
365 }
366 None => Ok(None),
367 }
368 }
369
370 pub async fn index_exists(
372 &self,
373 namespace: &common::NamespaceId,
374 index_type: storage::IndexType,
375 ) -> common::Result<bool> {
376 self.storage.index_exists(namespace, index_type).await
377 }
378
379 pub async fn delete_index(
381 &self,
382 namespace: &common::NamespaceId,
383 index_type: storage::IndexType,
384 ) -> common::Result<bool> {
385 self.storage.delete_index(namespace, index_type).await
386 }
387
388 pub async fn list_indexes(
390 &self,
391 namespace: &common::NamespaceId,
392 ) -> common::Result<Vec<storage::IndexType>> {
393 self.storage.list_indexes(namespace).await
394 }
395}
396
397pub fn serialize_to_bytes<T: Serialize>(value: &T) -> Result<Vec<u8>, String> {
399 serde_json::to_vec(value).map_err(|e| format!("Serialization failed: {}", e))
400}
401
402pub fn deserialize_from_bytes<T: DeserializeOwned>(data: &[u8]) -> Result<T, String> {
404 serde_json::from_slice(data).map_err(|e| format!("Deserialization failed: {}", e))
405}
406
// Round-trip tests: each test builds an index or snapshot, converts it to bytes, restores
// it, and checks that the restored value matches the original.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pq::PQConfig;
    use common::DistanceMetric;

    // Trains a product quantizer on 100 deterministic 32-d vectors, then checks that the
    // trained flag, dimension, config and codebooks survive a byte round trip.
    #[test]
    fn test_pq_quantizer_persistence() {
        use common::Vector;

        let config = PQConfig {
            num_subquantizers: 4,
            num_centroids: 16,
            kmeans_iterations: 10,
            distance_metric: DistanceMetric::Euclidean,
        };

        let mut pq = ProductQuantizer::new(config, 32).unwrap();

        // Sinusoidal data keeps the test deterministic without a RNG.
        let vectors: Vec<Vector> = (0..100)
            .map(|i| Vector {
                id: format!("v{}", i),
                values: (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect(),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            })
            .collect();

        pq.train(&vectors).unwrap();
        assert!(pq.is_trained());

        let bytes = pq.to_bytes().unwrap();
        assert!(!bytes.is_empty());

        let restored = ProductQuantizer::from_bytes(&bytes).unwrap();

        assert!(restored.is_trained());
        assert_eq!(restored.dimension, 32);
        assert_eq!(restored.config.num_subquantizers, 4);
        // Expects 4 codebooks, matching num_subquantizers = 4 above.
        assert_eq!(restored.codebooks.len(), 4);
    }

    // Round-trips an IvfTrainingSnapshot through the generic JSON helpers.
    #[test]
    fn test_ivf_training_snapshot() {
        let snapshot = IvfTrainingSnapshot {
            config: IvfConfig {
                n_clusters: 8,
                n_probe: 2,
                metric: DistanceMetric::Euclidean,
                ..Default::default()
            },
            dimension: 64,
            centroids: vec![vec![0.0; 64]; 8],
        };

        let bytes = serialize_to_bytes(&snapshot).unwrap();
        let restored: IvfTrainingSnapshot = deserialize_from_bytes(&bytes).unwrap();

        assert_eq!(restored.config.n_clusters, 8);
        assert_eq!(restored.dimension, 64);
        assert_eq!(restored.centroids.len(), 8);
    }

    // Round-trips an HnswConfigSnapshot built from the default config.
    #[test]
    fn test_hnsw_config_snapshot() {
        let snapshot = HnswConfigSnapshot {
            config: HnswConfig::default(),
            dimension: 128,
        };

        let bytes = serialize_to_bytes(&snapshot).unwrap();
        let restored: HnswConfigSnapshot = deserialize_from_bytes(&bytes).unwrap();

        // 16 is assumed to be HnswConfig::default().m; this breaks if that default changes.
        assert_eq!(restored.config.m, 16);
        assert_eq!(restored.dimension, 128);
    }

    // Round-trips an SpFreshTrainingSnapshot through the generic JSON helpers.
    #[test]
    fn test_spfresh_training_snapshot() {
        let snapshot = SpFreshTrainingSnapshot {
            config: SpFreshConfig::default(),
            dimension: 32,
            centroids: vec![vec![1.0; 32]; 16],
        };

        let bytes = serialize_to_bytes(&snapshot).unwrap();
        let restored: SpFreshTrainingSnapshot = deserialize_from_bytes(&bytes).unwrap();

        assert_eq!(restored.dimension, 32);
        assert_eq!(restored.centroids.len(), 16);
    }

    // Builds a 50-node HNSW graph, restores it from bytes, and requires identical search
    // results (same ids, scores within 1e-6) for the same query.
    #[test]
    fn test_hnsw_full_persistence() {
        use crate::hnsw::HnswIndex;

        // `index` is not `mut`, yet `insert` is called on it below, so insertion
        // presumably uses interior mutability.
        let index = HnswIndex::new();

        for i in 0..50 {
            let vector: Vec<f32> = (0..64).map(|j| ((i + j) as f32 * 0.1).sin()).collect();
            index.insert(format!("vec_{}", i), vector);
        }

        assert_eq!(index.len(), 50);

        let bytes = index.to_bytes().unwrap();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dimension(), index.dimension());
        assert_eq!(restored.max_level(), index.max_level());

        let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.1).sin()).collect();
        let original_results = index.search(&query, 5);
        let restored_results = restored.search(&query, 5);

        assert_eq!(original_results.len(), restored_results.len());
        // `.0` must match exactly; `.1` (a float score) must match within tolerance.
        for (orig, rest) in original_results.iter().zip(restored_results.iter()) {
            assert_eq!(orig.0, rest.0);
            assert!((orig.1 - rest.1).abs() < 1e-6);
        }
    }

    // An empty HNSW index must round-trip to an empty index
    // (its dimension is stored as 0 in the snapshot).
    #[test]
    fn test_hnsw_empty_persistence() {
        use crate::hnsw::HnswIndex;

        let index = HnswIndex::new();

        let bytes = index.to_bytes().unwrap();

        let restored = HnswIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 0);
        assert!(restored.is_empty());
    }

    // Trains and fills a 10-cluster IVF index, restores it, and checks that counts,
    // training state and the top search hit are preserved.
    #[test]
    fn test_ivf_full_persistence() {
        use crate::ivf::{IvfConfig, IvfIndex};

        let training_vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect())
            .collect();

        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            n_probe: 3,
            ..Default::default()
        });

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());

        for (i, v) in training_vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        assert_eq!(index.len(), 100);

        let bytes = index.to_bytes().unwrap();
        assert!(!bytes.is_empty());

        let restored = IvfIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 100);
        assert!(restored.is_trained());
        assert_eq!(restored.n_clusters(), 10);

        // Querying with a stored vector: both indexes must return that vector first.
        let query = &training_vectors[0];
        let original_results = index.search(query, 5).unwrap();
        let restored_results = restored.search(query, 5).unwrap();

        assert_eq!(original_results[0].id, restored_results[0].id);
        assert_eq!(original_results[0].id, "vec_0");
    }

    // An untrained, empty IVF index must round-trip and stay untrained.
    #[test]
    fn test_ivf_empty_persistence() {
        use crate::ivf::{IvfConfig, IvfIndex};

        let index = IvfIndex::new(IvfConfig::default());

        let bytes = index.to_bytes().unwrap();

        let restored = IvfIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 0);
        assert!(!restored.is_trained());
    }

    // Trains a 4-cluster SPFresh index on 100 vectors, restores it, and checks stats plus
    // top-10 search overlap between the original and restored index.
    #[test]
    fn test_spfresh_full_persistence() {
        use crate::spfresh::{SpFreshConfig, SpFreshIndex};
        use common::Vector;

        let training_vectors: Vec<Vector> = (0..100)
            .map(|i| Vector {
                id: format!("vec_{}", i),
                values: (0..32).map(|j| ((i + j) as f32 * 0.1).sin()).collect(),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            })
            .collect();

        let index = SpFreshIndex::new(SpFreshConfig {
            num_clusters: 4,
            n_probe: 2,
            ..Default::default()
        });

        // After train, stats report all 100 vectors as indexed.
        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());

        let stats = index.stats();
        assert_eq!(stats.total_vectors, 100);
        assert_eq!(stats.num_clusters, 4);

        let bytes = index.to_bytes().unwrap();
        assert!(!bytes.is_empty());

        let restored = SpFreshIndex::from_bytes(&bytes).unwrap();

        let restored_stats = restored.stats();
        assert!(restored.is_trained());
        assert_eq!(restored_stats.total_vectors, 100);
        assert_eq!(restored_stats.num_clusters, 4);
        assert_eq!(restored_stats.dimension, Some(32));

        let query = &training_vectors[50].values;
        let original_results = index.search(query, 10).unwrap();
        let restored_results = restored.search(query, 10).unwrap();

        assert_eq!(original_results.len(), restored_results.len());

        // Only 8/10 overlap is required rather than identical results.
        // NOTE(review): if restore is exact, this could be tightened to full equality.
        let original_ids: std::collections::HashSet<_> =
            original_results.iter().map(|r| &r.id).collect();
        let restored_ids: std::collections::HashSet<_> =
            restored_results.iter().map(|r| &r.id).collect();
        let overlap = original_ids.intersection(&restored_ids).count();
        assert!(
            overlap >= 8,
            "Expected at least 80% overlap in top-10 results, got {}/10",
            overlap
        );
    }

    // An untrained, empty SPFresh index must round-trip with zero vectors and stay untrained.
    #[test]
    fn test_spfresh_empty_persistence() {
        use crate::spfresh::{SpFreshConfig, SpFreshIndex};

        let index = SpFreshIndex::new(SpFreshConfig::default());

        let bytes = index.to_bytes().unwrap();

        let restored = SpFreshIndex::from_bytes(&bytes).unwrap();

        let stats = restored.stats();
        assert_eq!(stats.total_vectors, 0);
        assert!(!restored.is_trained());
    }

    // Removes 10 of 50 vectors, restores from bytes, and checks that the tombstones
    // persist: counts match and no removed id appears in search results.
    #[test]
    fn test_spfresh_persistence_with_tombstones() {
        use crate::spfresh::{SpFreshConfig, SpFreshIndex};
        use common::Vector;

        let training_vectors: Vec<Vector> = (0..50)
            .map(|i| Vector {
                id: format!("vec_{}", i),
                values: (0..16).map(|j| ((i + j) as f32 * 0.1).cos()).collect(),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            })
            .collect();

        let index = SpFreshIndex::new(SpFreshConfig {
            num_clusters: 2,
            ..Default::default()
        });

        index.train(&training_vectors).unwrap();

        // Remove vec_0..vec_9; `remove` returns how many ids were removed.
        let ids_to_remove: Vec<String> = (0..10).map(|i| format!("vec_{}", i)).collect();
        let removed = index.remove(&ids_to_remove);
        assert_eq!(removed, 10);

        let stats = index.stats();
        assert_eq!(stats.total_vectors, 40);
        assert_eq!(stats.total_tombstones, 10);

        let bytes = index.to_bytes().unwrap();

        let restored = SpFreshIndex::from_bytes(&bytes).unwrap();

        let restored_stats = restored.stats();
        assert_eq!(restored_stats.total_vectors, 40);
        assert_eq!(restored_stats.total_tombstones, 10);

        // k = 50 exceeds the 40 live vectors, so every returned hit is checked;
        // the query itself is a removed vector (vec_0).
        let results = restored.search(&training_vectors[0].values, 50).unwrap();
        for result in &results {
            let id_num: usize = result.id.strip_prefix("vec_").unwrap().parse().unwrap();
            assert!(
                id_num >= 10,
                "Tombstoned vector {} appeared in results",
                result.id
            );
        }
    }
}