1use cp_core::{CPError, Result};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10use tracing::{info, warn};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use uuid::Uuid;
13
/// Tuning knobs for a [`PersistentHnswIndex`].
///
/// Every field is forwarded to the underlying `usearch` index when it is
/// created, except `capacity`, which is passed to `reserve`.
#[derive(Debug, Clone)]
pub struct IndexConfig {
    /// Number of components in every stored vector and every query.
    pub dimensions: usize,
    /// Graph connectivity handed to `usearch` (`IndexOptions::connectivity`).
    pub connectivity: usize,
    /// Expansion factor used while adding vectors (`IndexOptions::expansion_add`).
    pub ef_construction: usize,
    /// Expansion factor used while searching (`IndexOptions::expansion_search`).
    pub ef_search: usize,
    /// Number of vector slots reserved up front.
    pub capacity: usize,
}
28
29impl Default for IndexConfig {
30 fn default() -> Self {
31 Self {
32 dimensions: 1536,
33 connectivity: 16,
34 ef_construction: 200,
35 ef_search: 50,
36 capacity: 100_000,
37 }
38 }
39}
40
41pub struct PersistentHnswIndex {
43 index: Index,
45 index_path: Option<PathBuf>,
47 checkpoint_root: Option<[u8; 32]>,
49 uuid_to_key: HashMap<Uuid, u64>,
51 key_to_uuid: HashMap<u64, Uuid>,
53 next_key: u64,
55 needs_rebuild: bool,
57 config: IndexConfig,
59}
60
61impl PersistentHnswIndex {
62 pub fn new(config: IndexConfig) -> Result<Self> {
64 let options = IndexOptions {
65 dimensions: config.dimensions,
66 metric: MetricKind::Cos,
67 quantization: ScalarKind::F32,
68 connectivity: config.connectivity,
69 expansion_add: config.ef_construction,
70 expansion_search: config.ef_search,
71 multi: false,
72 };
73
74 let index = Index::new(&options)
75 .map_err(|e| CPError::Database(format!("Failed to create index: {}", e)))?;
76
77 index
78 .reserve(config.capacity)
79 .map_err(|e| CPError::Database(format!("Failed to reserve capacity: {}", e)))?;
80
81 Ok(Self {
82 index,
83 index_path: None,
84 checkpoint_root: None,
85 uuid_to_key: HashMap::new(),
86 key_to_uuid: HashMap::new(),
87 next_key: 0,
88 needs_rebuild: false,
89 config,
90 })
91 }
92
93 pub fn open(path: PathBuf, config: IndexConfig) -> Result<Self> {
95 let options = IndexOptions {
96 dimensions: config.dimensions,
97 metric: MetricKind::Cos,
98 quantization: ScalarKind::F32,
99 connectivity: config.connectivity,
100 expansion_add: config.ef_construction,
101 expansion_search: config.ef_search,
102 multi: false,
103 };
104
105 let index = Index::new(&options)
106 .map_err(|e| CPError::Database(format!("Failed to create index: {}", e)))?;
107
108 let mut inst = Self {
109 index,
110 index_path: Some(path.clone()),
111 checkpoint_root: None,
112 uuid_to_key: HashMap::new(),
113 key_to_uuid: HashMap::new(),
114 next_key: 0,
115 needs_rebuild: false,
116 config,
117 };
118
119 if path.exists() {
121 match inst.load() {
122 Ok(_) => info!("Loaded existing index from {:?}", path),
123 Err(e) => {
124 warn!("Failed to load index: {}, will rebuild", e);
125 inst.needs_rebuild = true;
126 }
127 }
128 }
129
130 Ok(inst)
131 }
132
133 fn load(&mut self) -> Result<()> {
135 if let Some(path) = self.index_path.clone() {
136 self.index
137 .load(path.to_str().ok_or_else(|| CPError::Database("Non-UTF8 index path".into()))?)
138 .map_err(|e| CPError::Database(format!("Failed to load index: {}", e)))?;
139
140 let map_path = path.with_extension("map");
142 if map_path.exists() {
143 let data = std::fs::read(&map_path)
144 .map_err(|e| CPError::Database(format!("Failed to read map: {}", e)))?;
145 self.load_mapping(&data)?;
146 }
147
148 let checkpoint_path = path.with_extension("checkpoint");
150 if checkpoint_path.exists() {
151 let data = std::fs::read(&checkpoint_path)
152 .map_err(|e| CPError::Database(format!("Failed to read checkpoint: {}", e)))?;
153 if data.len() == 32 {
154 let mut root = [0u8; 32];
155 root.copy_from_slice(&data);
156 self.checkpoint_root = Some(root);
157 }
158 }
159 }
160 Ok(())
161 }
162
163 pub fn save(&self) -> Result<()> {
165 if let Some(path) = &self.index_path {
166 self.index
167 .save(path.to_str().ok_or_else(|| CPError::Database("Non-UTF8 index path".into()))?)
168 .map_err(|e| CPError::Database(format!("Failed to save index: {}", e)))?;
169
170 let map_path = path.with_extension("map");
172 let map_data = self.serialize_mapping();
173 std::fs::write(&map_path, map_data)
174 .map_err(|e| CPError::Database(format!("Failed to write map: {}", e)))?;
175
176 if let Some(root) = &self.checkpoint_root {
178 let checkpoint_path = path.with_extension("checkpoint");
179 std::fs::write(&checkpoint_path, root)
180 .map_err(|e| CPError::Database(format!("Failed to write checkpoint: {}", e)))?;
181 }
182
183 info!("Saved index to {:?}", path);
184 }
185 Ok(())
186 }
187
188 fn serialize_mapping(&self) -> Vec<u8> {
190 let mut data = Vec::new();
191 data.extend_from_slice(&self.next_key.to_le_bytes());
193 let count = self.uuid_to_key.len() as u64;
195 data.extend_from_slice(&count.to_le_bytes());
196 for (uuid, key) in &self.uuid_to_key {
198 data.extend_from_slice(uuid.as_bytes());
199 data.extend_from_slice(&key.to_le_bytes());
200 }
201 data
202 }
203
204 fn load_mapping(&mut self, data: &[u8]) -> Result<()> {
206 if data.len() < 16 {
207 return Err(CPError::Database("Invalid mapping data".into()));
208 }
209
210 let next_key = u64::from_le_bytes(data[0..8].try_into().unwrap());
211 let count = u64::from_le_bytes(data[8..16].try_into().unwrap()) as usize;
212
213 self.next_key = next_key;
214 self.uuid_to_key.clear();
215 self.key_to_uuid.clear();
216
217 let mut offset = 16;
218 for _ in 0..count {
219 if offset + 24 > data.len() {
220 return Err(CPError::Database("Truncated mapping data".into()));
221 }
222 let uuid = Uuid::from_slice(&data[offset..offset + 16])
223 .map_err(|_| CPError::Database("Invalid UUID in mapping".into()))?;
224 let key = u64::from_le_bytes(data[offset + 16..offset + 24].try_into().unwrap());
225 self.uuid_to_key.insert(uuid, key);
226 self.key_to_uuid.insert(key, uuid);
227 offset += 24;
228 }
229
230 Ok(())
231 }
232
233 pub fn insert(&mut self, emb_id: Uuid, vector: Vec<f32>) -> Result<()> {
235 let key = self.next_key;
236 self.next_key += 1;
237
238 self.index
239 .add(key, &vector)
240 .map_err(|e| CPError::Database(format!("Failed to add vector: {}", e)))?;
241
242 self.uuid_to_key.insert(emb_id, key);
243 self.key_to_uuid.insert(key, emb_id);
244
245 Ok(())
246 }
247
248 pub fn search(&self, query: &[f32], k: usize) -> Vec<(Uuid, f32)> {
250 match self.index.search(query, k) {
251 Ok(results) => {
252 results
253 .keys
254 .iter()
255 .zip(results.distances.iter())
256 .filter_map(|(&key, &dist)| {
257 self.key_to_uuid.get(&key).map(|id| {
258 (*id, 1.0 - dist)
260 })
261 })
262 .collect()
263 }
264 Err(e) => {
265 warn!("Search failed: {}", e);
266 Vec::new()
267 }
268 }
269 }
270
271 pub fn is_valid(&self, current_root: &[u8; 32]) -> bool {
273 match &self.checkpoint_root {
274 Some(root) => root == current_root,
275 None => false,
276 }
277 }
278
279 pub fn checkpoint(&mut self, state_root: [u8; 32]) -> Result<()> {
281 self.checkpoint_root = Some(state_root);
282 self.needs_rebuild = false;
283 self.save()
284 }
285
286 pub fn invalidate(&mut self) {
288 self.needs_rebuild = true;
289 }
290
291 pub fn needs_rebuild(&self) -> bool {
293 self.needs_rebuild
294 }
295
296 pub fn clear(&mut self) -> Result<()> {
298 let options = IndexOptions {
300 dimensions: self.config.dimensions,
301 metric: MetricKind::Cos,
302 quantization: ScalarKind::F32,
303 connectivity: self.config.connectivity,
304 expansion_add: self.config.ef_construction,
305 expansion_search: self.config.ef_search,
306 multi: false,
307 };
308
309 self.index = Index::new(&options)
310 .map_err(|e| CPError::Database(format!("Failed to recreate index: {}", e)))?;
311
312 self.index
313 .reserve(self.config.capacity)
314 .map_err(|e| CPError::Database(format!("Failed to reserve capacity: {}", e)))?;
315
316 self.uuid_to_key.clear();
317 self.key_to_uuid.clear();
318 self.next_key = 0;
319 self.checkpoint_root = None;
320 self.needs_rebuild = false;
321 Ok(())
322 }
323
324 pub fn len(&self) -> usize {
326 self.index.size()
327 }
328
329 pub fn is_empty(&self) -> bool {
331 self.len() == 0
332 }
333}
334
335pub struct SharedPersistentIndex {
337 inner: Arc<RwLock<PersistentHnswIndex>>,
338}
339
340impl SharedPersistentIndex {
341 pub fn new(config: IndexConfig) -> Result<Self> {
342 Ok(Self {
343 inner: Arc::new(RwLock::new(PersistentHnswIndex::new(config)?)),
344 })
345 }
346
347 pub fn open(path: PathBuf, config: IndexConfig) -> Result<Self> {
348 Ok(Self {
349 inner: Arc::new(RwLock::new(PersistentHnswIndex::open(path, config)?)),
350 })
351 }
352
353 pub fn insert(&self, emb_id: Uuid, vector: Vec<f32>) -> Result<()> {
354 let mut index = self.inner.write().expect("hnsw index lock poisoned");
355 index.insert(emb_id, vector)
356 }
357
358 pub fn search(&self, query: &[f32], k: usize) -> Vec<(Uuid, f32)> {
359 let index = self.inner.read().expect("hnsw index lock poisoned");
360 index.search(query, k)
361 }
362
363 pub fn save(&self) -> Result<()> {
364 let index = self.inner.read().expect("hnsw index lock poisoned");
365 index.save()
366 }
367
368 pub fn checkpoint(&self, state_root: [u8; 32]) -> Result<()> {
369 let mut index = self.inner.write().expect("hnsw index lock poisoned");
370 index.checkpoint(state_root)
371 }
372
373 pub fn is_valid(&self, current_root: &[u8; 32]) -> bool {
374 let index = self.inner.read().expect("hnsw index lock poisoned");
375 index.is_valid(current_root)
376 }
377
378 pub fn invalidate(&self) {
379 let mut index = self.inner.write().expect("hnsw index lock poisoned");
380 index.invalidate();
381 }
382
383 pub fn needs_rebuild(&self) -> bool {
384 let index = self.inner.read().expect("hnsw index lock poisoned");
385 index.needs_rebuild()
386 }
387
388 pub fn clear(&self) -> Result<()> {
389 let mut index = self.inner.write().expect("hnsw index lock poisoned");
390 index.clear()
391 }
392
393 pub fn len(&self) -> usize {
394 let index = self.inner.read().expect("hnsw index lock poisoned");
395 index.len()
396 }
397
398 pub fn is_empty(&self) -> bool {
399 self.len() == 0
400 }
401}
402
403impl Clone for SharedPersistentIndex {
404 fn clone(&self) -> Self {
405 Self {
406 inner: Arc::clone(&self.inner),
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory index with the default configuration.
    fn fresh_index() -> PersistentHnswIndex {
        PersistentHnswIndex::new(IndexConfig::default()).unwrap()
    }

    /// 1536-dimensional vector that is 1.0 at `axis` and 0.0 elsewhere.
    fn one_hot(axis: i32) -> Vec<f32> {
        (0..1536_i32)
            .map(|i| if i == axis { 1.0 } else { 0.0 })
            .collect()
    }

    /// 1536-dimensional vector whose i-th component is `i * 0.01`.
    fn ramp() -> Vec<f32> {
        (0..1536_i32).map(|i| i as f32 * 0.01).collect()
    }

    /// 1536-dimensional vector whose j-th component is `(factor * j) * 0.001`.
    fn scaled_ramp(factor: i32) -> Vec<f32> {
        (0..1536_i32)
            .map(|j| (factor * j) as f32 * 0.001)
            .collect()
    }

    /// Inserts `count` scaled ramps (factors `0..count`) under fresh ids.
    fn insert_scaled_ramps(index: &mut PersistentHnswIndex, count: i32) {
        for factor in 0..count {
            index.insert(Uuid::new_v4(), scaled_ramp(factor)).unwrap();
        }
    }

    #[test]
    fn test_in_memory_index() {
        let mut index = fresh_index();

        let (id1, id2) = (Uuid::new_v4(), Uuid::new_v4());
        let v1 = one_hot(0);

        index.insert(id1, v1.clone()).unwrap();
        index.insert(id2, one_hot(1)).unwrap();

        let results = index.search(&v1, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
    }

    #[test]
    fn test_index_invalidation() {
        let mut index = fresh_index();
        assert!(!index.needs_rebuild());

        index.invalidate();
        assert!(index.needs_rebuild());

        // A checkpoint clears the rebuild flag again.
        index.checkpoint([1u8; 32]).unwrap();
        assert!(!index.needs_rebuild());
    }

    #[test]
    fn test_hnsw_index_new() {
        let index = fresh_index();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_hnsw_index_add_vector() {
        let mut index = fresh_index();

        index.insert(Uuid::new_v4(), ramp()).unwrap();

        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_hnsw_index_search() {
        let mut index = fresh_index();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();
        let v1 = one_hot(0);

        index.insert(id1, v1.clone()).unwrap();
        index.insert(id2, one_hot(1)).unwrap();
        index.insert(id3, one_hot(2)).unwrap();

        let results = index.search(&v1, 3);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, id1);
    }

    #[test]
    fn test_hnsw_index_search_k_results() {
        let mut index = fresh_index();
        insert_scaled_ramps(&mut index, 5);

        let query = scaled_ramp(1);
        let results = index.search(&query, 3);
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_hnsw_index_search_empty_query() {
        let mut index = fresh_index();
        index.insert(Uuid::new_v4(), ramp()).unwrap();

        // Only checks that searching with an empty query does not panic.
        let empty_query: Vec<f32> = Vec::new();
        let _results = index.search(&empty_query, 5);
    }

    #[test]
    fn test_hnsw_index_delete_vector() {
        let mut index = fresh_index();

        index.insert(Uuid::new_v4(), ramp()).unwrap();
        assert_eq!(index.len(), 1);

        index.clear().unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_hnsw_index_rebuild_from_sqlite() {
        let mut index = fresh_index();

        insert_scaled_ramps(&mut index, 5);
        assert_eq!(index.len(), 5);

        index.clear().unwrap();
        assert_eq!(index.len(), 0);

        // Re-populating after a clear must work just like the first fill.
        insert_scaled_ramps(&mut index, 5);
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn test_hnsw_index_consistency_with_sqlite() {
        let mut index = fresh_index();

        let initial_len = index.len();
        assert_eq!(initial_len, 0);

        index.insert(Uuid::new_v4(), ramp()).unwrap();

        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_hnsw_index_cosine_similarity() {
        let mut index = fresh_index();

        let id1 = Uuid::new_v4();
        let v1 = vec![0.5_f32; 1536];

        let id2 = Uuid::new_v4();
        let v2 = vec![0.5_f32; 1536];

        let id3 = Uuid::new_v4();
        let v3 = vec![-0.5_f32; 1536];

        index.insert(id1, v1.clone()).unwrap();
        index.insert(id2, v2).unwrap();
        index.insert(id3, v3).unwrap();

        let results = index.search(&v1, 3);

        // Results are expected best-first; only checked when all three return.
        if results.len() >= 3 {
            assert!(results[0].1 >= results[1].1);
        }
    }

    #[test]
    fn test_hnsw_index_empty_index_search() {
        let index = fresh_index();

        let results = index.search(&ramp(), 5);

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hnsw_index_batch_add() {
        let mut index = fresh_index();

        insert_scaled_ramps(&mut index, 10);

        assert_eq!(index.len(), 10);
    }

    #[test]
    fn test_hnsw_index_m_configuration() {
        let config = IndexConfig::default();

        assert_eq!(config.connectivity, 16);
    }

    #[test]
    fn test_hnsw_index_ef_configuration() {
        let config = IndexConfig::default();

        assert_eq!(config.ef_construction, 200);
    }

    #[test]
    fn test_hnsw_index_is_valid() {
        let mut index = fresh_index();

        let test_root = [1u8; 32];
        assert!(!index.is_valid(&test_root));

        index.checkpoint(test_root).unwrap();
        assert!(index.is_valid(&test_root));

        let different_root = [2u8; 32];
        assert!(!index.is_valid(&different_root));
    }

    #[test]
    fn test_hnsw_index_checkpoint() {
        let mut index = fresh_index();

        let root = [1u8; 32];
        index.checkpoint(root).unwrap();

        assert!(index.is_valid(&root));
    }

    #[test]
    fn test_shared_persistent_index() {
        let index = SharedPersistentIndex::new(IndexConfig::default()).unwrap();
        assert!(index.is_empty());

        let vector = ramp();
        index.insert(Uuid::new_v4(), vector.clone()).unwrap();
        assert_eq!(index.len(), 1);

        let results = index.search(&vector, 5);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_shared_persistent_index_clone() {
        let index1 = SharedPersistentIndex::new(IndexConfig::default()).unwrap();
        let index2 = index1.clone();

        index1.insert(Uuid::new_v4(), ramp()).unwrap();

        // The clone shares state with the original handle.
        assert_eq!(index2.len(), 1);
    }

    #[test]
    fn test_index_clear() {
        let mut index = fresh_index();

        for i in 0..5_i32 {
            let vector = vec![i as f32 * 0.01; 1536];
            index.insert(Uuid::new_v4(), vector).unwrap();
        }
        assert_eq!(index.len(), 5);

        index.clear().unwrap();

        assert_eq!(index.len(), 0);
    }
}