//! In-memory vector store (`contextdb_vector/store.rs`): rows plus an
//! optional, shared HNSW index slot with memory accounting.
1use crate::HnswIndex;
2use contextdb_core::{MemoryAccountant, RowId, TxId, VectorEntry};
3use parking_lot::RwLock;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, OnceLock};
6
pub struct VectorStore {
    /// All vector rows, including soft-deleted ones (entries whose
    /// `deleted_tx` is `Some`); rows are never physically removed here.
    pub vectors: RwLock<Vec<VectorEntry>>,
    /// Vector dimensionality; set once by the first vector seen (or an
    /// explicit `set_dimension`) and never overwritten afterwards.
    pub dimension: RwLock<Option<usize>>,
    /// Shared, lazily-initialized slot for the optional HNSW index.
    /// `None` inside the inner lock means the index is currently
    /// dropped/invalidated and must be rebuilt before use.
    pub hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
    /// Bytes charged to the `MemoryAccountant` for the current HNSW index;
    /// released by `clear_hnsw`.
    hnsw_bytes: AtomicUsize,
}
13
14impl Default for VectorStore {
15    fn default() -> Self {
16        Self::new(Arc::new(OnceLock::new()))
17    }
18}
19
20impl VectorStore {
21    pub fn new(hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>) -> Self {
22        Self {
23            vectors: RwLock::new(Vec::new()),
24            dimension: RwLock::new(None),
25            hnsw,
26            hnsw_bytes: AtomicUsize::new(0),
27        }
28    }
29
30    pub fn apply_inserts(&self, inserts: Vec<VectorEntry>) {
31        {
32            let mut vectors = self.vectors.write();
33            for entry in &inserts {
34                vectors.push(entry.clone());
35            }
36        }
37
38        if let Some(rw_lock) = self.hnsw.get() {
39            let guard = rw_lock.write();
40            if let Some(hnsw) = guard.as_ref() {
41                for entry in &inserts {
42                    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
43                        hnsw.insert(entry.row_id, &entry.vector);
44                    }));
45                }
46            }
47        }
48    }
49
50    pub fn apply_deletes(&self, deletes: Vec<(RowId, TxId)>) {
51        let mut vectors = self.vectors.write();
52        for (row_id, deleted_tx) in deletes {
53            for v in vectors.iter_mut() {
54                if v.row_id == row_id && v.deleted_tx.is_none() {
55                    v.deleted_tx = Some(deleted_tx);
56                }
57            }
58        }
59        drop(vectors);
60
61        if let Some(rw_lock) = self.hnsw.get() {
62            *rw_lock.write() = None;
63        }
64    }
65
66    pub fn clear_hnsw(&self, accountant: &MemoryAccountant) {
67        let bytes = self.hnsw_bytes.swap(0, Ordering::SeqCst);
68        if bytes > 0 {
69            accountant.release(bytes);
70        }
71        if let Some(rw_lock) = self.hnsw.get() {
72            *rw_lock.write() = None;
73        }
74    }
75
76    pub fn insert_loaded_vector(&self, entry: VectorEntry) {
77        let dimension = entry.vector.len();
78        let mut dim = self.dimension.write();
79        if dim.is_none() {
80            *dim = Some(dimension);
81        }
82        drop(dim);
83        self.vectors.write().push(entry);
84    }
85
86    pub fn set_dimension(&self, dimension: usize) {
87        let mut dim = self.dimension.write();
88        if dim.is_none() {
89            *dim = Some(dimension);
90        }
91    }
92
93    pub fn vector_count(&self) -> usize {
94        self.vectors.read().len()
95    }
96
97    pub fn all_entries(&self) -> Vec<VectorEntry> {
98        self.vectors.read().clone()
99    }
100
101    pub fn dimension(&self) -> Option<usize> {
102        *self.dimension.read()
103    }
104
105    pub fn has_hnsw_index(&self) -> bool {
106        self.hnsw
107            .get()
108            .is_some_and(|rw_lock| rw_lock.read().as_ref().is_some())
109    }
110
111    pub fn set_hnsw_bytes(&self, bytes: usize) {
112        self.hnsw_bytes.store(bytes, Ordering::SeqCst);
113    }
114}