1use cp_core::{CPError, Result};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10use tracing::{info, warn};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use uuid::Uuid;
13
/// Construction parameters for a [`PersistentHnswIndex`].
///
/// The HNSW parameters are forwarded to usearch's `IndexOptions` whenever the native
/// index is created or recreated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexConfig {
    /// Length every inserted vector and every query must have (`IndexOptions::dimensions`).
    pub dimensions: usize,
    /// HNSW graph connectivity, passed as `IndexOptions::connectivity`.
    pub connectivity: usize,
    /// Candidate-list size used while inserting, passed as `IndexOptions::expansion_add`.
    pub ef_construction: usize,
    /// Candidate-list size used while searching, passed as `IndexOptions::expansion_search`.
    pub ef_search: usize,
    /// Number of vector slots reserved up front when the native index is created.
    pub capacity: usize,
}
28
29impl Default for IndexConfig {
30 fn default() -> Self {
31 Self {
32 dimensions: 1536,
33 connectivity: 16,
34 ef_construction: 200,
35 ef_search: 50,
36 capacity: 100_000,
37 }
38 }
39}
40
41pub struct PersistentHnswIndex {
43 index: Index,
45 index_path: Option<PathBuf>,
47 checkpoint_root: Option<[u8; 32]>,
49 uuid_to_key: HashMap<Uuid, u64>,
51 key_to_uuid: HashMap<u64, Uuid>,
53 next_key: u64,
55 needs_rebuild: bool,
57 config: IndexConfig,
59}
60
61impl PersistentHnswIndex {
62 pub fn new(config: IndexConfig) -> Result<Self> {
64 let options = IndexOptions {
65 dimensions: config.dimensions,
66 metric: MetricKind::Cos,
67 quantization: ScalarKind::F32,
68 connectivity: config.connectivity,
69 expansion_add: config.ef_construction,
70 expansion_search: config.ef_search,
71 multi: false,
72 };
73
74 let index = Index::new(&options)
75 .map_err(|e| CPError::Database(format!("Failed to create index: {e}")))?;
76
77 index
78 .reserve(config.capacity)
79 .map_err(|e| CPError::Database(format!("Failed to reserve capacity: {e}")))?;
80
81 Ok(Self {
82 index,
83 index_path: None,
84 checkpoint_root: None,
85 uuid_to_key: HashMap::new(),
86 key_to_uuid: HashMap::new(),
87 next_key: 0,
88 needs_rebuild: false,
89 config,
90 })
91 }
92
93 pub fn open(path: &std::path::Path, config: IndexConfig) -> Result<Self> {
95 let options = IndexOptions {
96 dimensions: config.dimensions,
97 metric: MetricKind::Cos,
98 quantization: ScalarKind::F32,
99 connectivity: config.connectivity,
100 expansion_add: config.ef_construction,
101 expansion_search: config.ef_search,
102 multi: false,
103 };
104
105 let index = Index::new(&options)
106 .map_err(|e| CPError::Database(format!("Failed to create index: {e}")))?;
107
108 index
109 .reserve(config.capacity)
110 .map_err(|e| CPError::Database(format!("Failed to reserve capacity: {e}")))?;
111
112 let mut inst = Self {
113 index,
114 index_path: Some(path.to_path_buf()),
115 checkpoint_root: None,
116 uuid_to_key: HashMap::new(),
117 key_to_uuid: HashMap::new(),
118 next_key: 0,
119 needs_rebuild: false,
120 config,
121 };
122
123 if path.exists() {
125 match inst.load() {
126 Ok(()) => info!("Loaded existing index from {:?}", path),
127 Err(e) => {
128 warn!("Failed to load index: {}, will rebuild", e);
129 inst.needs_rebuild = true;
130 }
131 }
132 }
133
134 Ok(inst)
135 }
136
137 fn load(&mut self) -> Result<()> {
139 if let Some(path) = self.index_path.clone() {
140 self.index
141 .load(
142 path.to_str()
143 .ok_or_else(|| CPError::Database("Non-UTF8 index path".into()))?,
144 )
145 .map_err(|e| CPError::Database(format!("Failed to load index: {e}")))?;
146
147 let map_path = path.with_extension("map");
149 if map_path.exists() {
150 let data = std::fs::read(&map_path)
151 .map_err(|e| CPError::Database(format!("Failed to read map: {e}")))?;
152 self.load_mapping(&data)?;
153 }
154
155 let checkpoint_path = path.with_extension("checkpoint");
157 if checkpoint_path.exists() {
158 let data = std::fs::read(&checkpoint_path)
159 .map_err(|e| CPError::Database(format!("Failed to read checkpoint: {e}")))?;
160 if data.len() == 32 {
161 let mut root = [0u8; 32];
162 root.copy_from_slice(&data);
163 self.checkpoint_root = Some(root);
164 }
165 }
166 }
167 Ok(())
168 }
169
170 pub fn save(&self) -> Result<()> {
172 if let Some(path) = &self.index_path {
173 self.index
174 .save(
175 path.to_str()
176 .ok_or_else(|| CPError::Database("Non-UTF8 index path".into()))?,
177 )
178 .map_err(|e| CPError::Database(format!("Failed to save index: {e}")))?;
179
180 let map_path = path.with_extension("map");
182 let map_data = self.serialize_mapping();
183 std::fs::write(&map_path, map_data)
184 .map_err(|e| CPError::Database(format!("Failed to write map: {e}")))?;
185
186 if let Some(root) = &self.checkpoint_root {
188 let checkpoint_path = path.with_extension("checkpoint");
189 std::fs::write(&checkpoint_path, root)
190 .map_err(|e| CPError::Database(format!("Failed to write checkpoint: {e}")))?;
191 }
192
193 info!("Saved index to {:?}", path);
194 }
195 Ok(())
196 }
197
198 fn serialize_mapping(&self) -> Vec<u8> {
200 let mut data = Vec::new();
201 data.extend_from_slice(&self.next_key.to_le_bytes());
203 let count = self.uuid_to_key.len() as u64;
205 data.extend_from_slice(&count.to_le_bytes());
206 for (uuid, key) in &self.uuid_to_key {
208 data.extend_from_slice(uuid.as_bytes());
209 data.extend_from_slice(&key.to_le_bytes());
210 }
211 data
212 }
213
214 fn load_mapping(&mut self, data: &[u8]) -> Result<()> {
216 if data.len() < 16 {
217 return Err(CPError::Database("Invalid mapping data".into()));
218 }
219
220 let next_key = u64::from_le_bytes(data[0..8].try_into().unwrap());
221 let count = u64::from_le_bytes(data[8..16].try_into().unwrap()) as usize;
222
223 self.next_key = next_key;
224 self.uuid_to_key.clear();
225 self.key_to_uuid.clear();
226
227 let mut offset = 16;
228 for _ in 0..count {
229 if offset + 24 > data.len() {
230 return Err(CPError::Database("Truncated mapping data".into()));
231 }
232 let uuid = Uuid::from_slice(&data[offset..offset + 16])
233 .map_err(|_| CPError::Database("Invalid UUID in mapping".into()))?;
234 let key = u64::from_le_bytes(data[offset + 16..offset + 24].try_into().unwrap());
235 self.uuid_to_key.insert(uuid, key);
236 self.key_to_uuid.insert(key, uuid);
237 offset += 24;
238 }
239
240 Ok(())
241 }
242
243 pub fn insert(&mut self, emb_id: Uuid, vector: &[f32]) -> Result<()> {
245 let key = self.next_key;
246 self.next_key += 1;
247
248 self.index
249 .add(key, vector)
250 .map_err(|e| CPError::Database(format!("Failed to add vector: {e}")))?;
251
252 self.uuid_to_key.insert(emb_id, key);
253 self.key_to_uuid.insert(key, emb_id);
254
255 Ok(())
256 }
257
258 pub fn search(&self, query: &[f32], k: usize) -> Vec<(Uuid, f32)> {
260 match self.index.search(query, k) {
261 Ok(results) => {
262 results
263 .keys
264 .iter()
265 .zip(results.distances.iter())
266 .filter_map(|(&key, &dist)| {
267 self.key_to_uuid.get(&key).map(|id| {
268 (*id, 1.0 - dist)
270 })
271 })
272 .collect()
273 }
274 Err(e) => {
275 warn!("Search failed: {}", e);
276 Vec::new()
277 }
278 }
279 }
280
281 pub fn is_valid(&self, current_root: &[u8; 32]) -> bool {
283 match &self.checkpoint_root {
284 Some(root) => root == current_root,
285 None => false,
286 }
287 }
288
289 pub fn checkpoint(&mut self, state_root: [u8; 32]) -> Result<()> {
291 self.checkpoint_root = Some(state_root);
292 self.needs_rebuild = false;
293 self.save()
294 }
295
296 pub fn invalidate(&mut self) {
298 self.needs_rebuild = true;
299 }
300
301 pub fn needs_rebuild(&self) -> bool {
303 self.needs_rebuild
304 }
305
306 pub fn clear(&mut self) -> Result<()> {
308 let options = IndexOptions {
310 dimensions: self.config.dimensions,
311 metric: MetricKind::Cos,
312 quantization: ScalarKind::F32,
313 connectivity: self.config.connectivity,
314 expansion_add: self.config.ef_construction,
315 expansion_search: self.config.ef_search,
316 multi: false,
317 };
318
319 self.index = Index::new(&options)
320 .map_err(|e| CPError::Database(format!("Failed to recreate index: {e}")))?;
321
322 self.index
323 .reserve(self.config.capacity)
324 .map_err(|e| CPError::Database(format!("Failed to reserve capacity: {e}")))?;
325
326 self.uuid_to_key.clear();
327 self.key_to_uuid.clear();
328 self.next_key = 0;
329 self.checkpoint_root = None;
330 self.needs_rebuild = false;
331 Ok(())
332 }
333
334 pub fn len(&self) -> usize {
336 self.index.size()
337 }
338
339 pub fn is_empty(&self) -> bool {
341 self.len() == 0
342 }
343}
344
345pub struct SharedPersistentIndex {
347 inner: Arc<RwLock<PersistentHnswIndex>>,
348}
349
350impl SharedPersistentIndex {
351 pub fn new(config: IndexConfig) -> Result<Self> {
352 Ok(Self {
353 inner: Arc::new(RwLock::new(PersistentHnswIndex::new(config)?)),
354 })
355 }
356
357 pub fn open(path: &std::path::Path, config: IndexConfig) -> Result<Self> {
358 Ok(Self {
359 inner: Arc::new(RwLock::new(PersistentHnswIndex::open(path, config)?)),
360 })
361 }
362
363 pub fn insert(&self, emb_id: Uuid, vector: &[f32]) -> Result<()> {
364 let mut index = self.inner.write().expect("hnsw index lock poisoned");
365 index.insert(emb_id, vector)
366 }
367
368 pub fn search(&self, query: &[f32], k: usize) -> Vec<(Uuid, f32)> {
369 let index = self.inner.read().expect("hnsw index lock poisoned");
370 index.search(query, k)
371 }
372
373 pub fn save(&self) -> Result<()> {
374 let index = self.inner.read().expect("hnsw index lock poisoned");
375 index.save()
376 }
377
378 pub fn checkpoint(&self, state_root: [u8; 32]) -> Result<()> {
379 let mut index = self.inner.write().expect("hnsw index lock poisoned");
380 index.checkpoint(state_root)
381 }
382
383 pub fn is_valid(&self, current_root: &[u8; 32]) -> bool {
384 let index = self.inner.read().expect("hnsw index lock poisoned");
385 index.is_valid(current_root)
386 }
387
388 pub fn invalidate(&self) {
389 let mut index = self.inner.write().expect("hnsw index lock poisoned");
390 index.invalidate();
391 }
392
393 pub fn needs_rebuild(&self) -> bool {
394 let index = self.inner.read().expect("hnsw index lock poisoned");
395 index.needs_rebuild()
396 }
397
398 pub fn clear(&self) -> Result<()> {
399 let mut index = self.inner.write().expect("hnsw index lock poisoned");
400 index.clear()
401 }
402
403 pub fn len(&self) -> usize {
404 let index = self.inner.read().expect("hnsw index lock poisoned");
405 index.len()
406 }
407
408 pub fn is_empty(&self) -> bool {
409 self.len() == 0
410 }
411}
412
413impl Clone for SharedPersistentIndex {
414 fn clone(&self) -> Self {
415 Self {
416 inner: Arc::clone(&self.inner),
417 }
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// 1536-dimensional vector with 1.0 at index `hot` and 0.0 elsewhere.
    fn one_hot(hot: usize) -> Vec<f32> {
        (0..1536).map(|i| if i == hot { 1.0 } else { 0.0 }).collect()
    }

    /// Small configuration for tests that do not depend on the default dimensionality.
    fn small_config() -> IndexConfig {
        IndexConfig {
            dimensions: 4,
            capacity: 16,
            ..IndexConfig::default()
        }
    }

    #[test]
    fn test_in_memory_index() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        index.insert(id1, &one_hot(0)).unwrap();
        index.insert(id2, &one_hot(1)).unwrap();

        let results = index.search(&one_hot(0), 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
    }

    #[test]
    fn test_index_invalidation() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        assert!(!index.needs_rebuild());

        index.invalidate();
        assert!(index.needs_rebuild());

        index.checkpoint([1u8; 32]).unwrap();
        assert!(!index.needs_rebuild());
    }

    #[test]
    fn test_hnsw_index_new() {
        let index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_hnsw_index_add_vector() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();

        index.insert(id, &vector).unwrap();

        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_hnsw_index_search() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();

        index.insert(id1, &one_hot(0)).unwrap();
        index.insert(id2, &one_hot(1)).unwrap();
        index.insert(id3, &one_hot(2)).unwrap();

        let results = index.search(&one_hot(0), 3);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, id1);
    }

    #[test]
    fn test_hnsw_index_search_k_results() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        for i in 0..5 {
            let id = Uuid::new_v4();
            let vector: Vec<f32> = (0..1536).map(|j| (i * j) as f32 * 0.001).collect();
            index.insert(id, &vector).unwrap();
        }

        let query: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        let results = index.search(&query, 3);
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_hnsw_index_search_empty_query() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();
        index.insert(id, &vector).unwrap();

        // Only checks that a malformed (empty) query does not panic.
        let empty_query: Vec<f32> = vec![];
        let _results = index.search(&empty_query, 5);
    }

    #[test]
    fn test_hnsw_index_delete_vector() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();

        index.insert(id, &vector).unwrap();
        assert_eq!(index.len(), 1);

        index.clear().unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_hnsw_index_rebuild_from_sqlite() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        for i in 0..5 {
            let id = Uuid::new_v4();
            let vector: Vec<f32> = (0..1536).map(|j| (i * j) as f32 * 0.001).collect();
            index.insert(id, &vector).unwrap();
        }
        assert_eq!(index.len(), 5);

        index.clear().unwrap();
        assert_eq!(index.len(), 0);

        for i in 0..5 {
            let id = Uuid::new_v4();
            let vector: Vec<f32> = (0..1536).map(|j| (i * j) as f32 * 0.001).collect();
            index.insert(id, &vector).unwrap();
        }
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn test_hnsw_index_consistency_with_sqlite() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        assert_eq!(index.len(), 0);

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();
        index.insert(id, &vector).unwrap();

        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_hnsw_index_cosine_similarity() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let id1 = Uuid::new_v4();
        let v1: Vec<f32> = vec![0.5_f32; 1536];
        let id2 = Uuid::new_v4();
        let v2: Vec<f32> = vec![0.5_f32; 1536];
        let id3 = Uuid::new_v4();
        let v3: Vec<f32> = vec![-0.5_f32; 1536];

        index.insert(id1, &v1).unwrap();
        index.insert(id2, &v2).unwrap();
        index.insert(id3, &v3).unwrap();

        let results = index.search(&v1, 3);
        if results.len() >= 3 {
            assert!(results[0].1 >= results[1].1);
        }
    }

    #[test]
    fn test_hnsw_index_empty_index_search() {
        let index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let query: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();
        let results = index.search(&query, 5);

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hnsw_index_batch_add() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        for batch_idx in 0..10 {
            let id = Uuid::new_v4();
            let vector: Vec<f32> = (0..1536).map(|i| (batch_idx * i) as f32 * 0.001).collect();
            index.insert(id, &vector).unwrap();
        }

        assert_eq!(index.len(), 10);
    }

    #[test]
    fn test_hnsw_index_m_configuration() {
        let config = IndexConfig::default();
        assert_eq!(config.connectivity, 16);
    }

    #[test]
    fn test_hnsw_index_ef_configuration() {
        let config = IndexConfig::default();
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef_search, 50);
    }

    #[test]
    fn test_hnsw_index_is_valid() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let test_root = [1u8; 32];
        assert!(!index.is_valid(&test_root));

        index.checkpoint(test_root).unwrap();
        assert!(index.is_valid(&test_root));

        let different_root = [2u8; 32];
        assert!(!index.is_valid(&different_root));
    }

    #[test]
    fn test_hnsw_index_checkpoint() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        let root = [1u8; 32];
        index.checkpoint(root).unwrap();

        assert!(index.is_valid(&root));
    }

    #[test]
    fn test_mapping_roundtrip() {
        let mut index = PersistentHnswIndex::new(small_config()).unwrap();
        let ids: Vec<Uuid> = (0..3).map(|_| Uuid::new_v4()).collect();
        for (i, id) in ids.iter().enumerate() {
            let mut vector = vec![0.0_f32; 4];
            vector[i] = 1.0;
            index.insert(*id, &vector).unwrap();
        }

        let bytes = index.serialize_mapping();
        assert_eq!(bytes.len(), 16 + 3 * 24);

        let mut restored = PersistentHnswIndex::new(small_config()).unwrap();
        restored.load_mapping(&bytes).unwrap();
        assert_eq!(restored.next_key, 3);
        assert_eq!(restored.uuid_to_key, index.uuid_to_key);
        assert_eq!(restored.key_to_uuid, index.key_to_uuid);
    }

    #[test]
    fn test_load_mapping_rejects_short_and_truncated_data() {
        let mut index = PersistentHnswIndex::new(small_config()).unwrap();

        assert!(index.load_mapping(&[0u8; 8]).is_err());

        // Header announces two entries but only one follows.
        let mut data = Vec::new();
        data.extend_from_slice(&0u64.to_le_bytes());
        data.extend_from_slice(&2u64.to_le_bytes());
        data.extend_from_slice(&[0u8; 24]);
        assert!(index.load_mapping(&data).is_err());
    }

    #[test]
    fn test_save_and_reopen_roundtrip() {
        let dir = std::env::temp_dir().join(format!("hnsw-index-test-{}", Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("index.usearch");
        let root = [7u8; 32];
        let id = Uuid::new_v4();
        let vector = vec![1.0_f32, 0.0, 0.0, 0.0];

        {
            let mut index = PersistentHnswIndex::open(&path, small_config()).unwrap();
            assert!(!index.needs_rebuild());
            index.insert(id, &vector).unwrap();
            index.checkpoint(root).unwrap();
        }

        let reopened = PersistentHnswIndex::open(&path, small_config()).unwrap();
        let needs_rebuild = reopened.needs_rebuild();
        let valid = reopened.is_valid(&root);
        let results = reopened.search(&vector, 1);
        drop(reopened);
        let _ = std::fs::remove_dir_all(&dir);

        assert!(!needs_rebuild);
        assert!(valid);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, id);
    }

    #[test]
    fn test_shared_persistent_index() {
        let index = SharedPersistentIndex::new(IndexConfig::default()).unwrap();

        assert!(index.is_empty());

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();
        index.insert(id, &vector).unwrap();

        assert_eq!(index.len(), 1);

        let results = index.search(&vector, 5);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_shared_persistent_index_clone() {
        let index1 = SharedPersistentIndex::new(IndexConfig::default()).unwrap();
        let index2 = index1.clone();

        let id = Uuid::new_v4();
        let vector: Vec<f32> = (0..1536).map(|i| i as f32 * 0.01).collect();

        index1.insert(id, &vector).unwrap();

        assert_eq!(index2.len(), 1);
    }

    #[test]
    fn test_index_clear() {
        let mut index = PersistentHnswIndex::new(IndexConfig::default()).unwrap();

        for i in 0..5 {
            let id = Uuid::new_v4();
            let vector: Vec<f32> = (0..1536).map(|_| i as f32 * 0.01).collect();
            index.insert(id, &vector).unwrap();
        }
        assert_eq!(index.len(), 5);

        index.clear().unwrap();
        assert_eq!(index.len(), 0);
    }
}