// contextdb_vector/store.rs

1use crate::{HnswIndex, quantized::StoredVectorEntry};
2use contextdb_core::{
3    Error, MemoryAccountant, Result, RowId, TxId, VectorEntry, VectorIndexRef, VectorQuantization,
4};
5use parking_lot::{Mutex, RwLock};
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, OnceLock};
9
/// In-memory state kept for one registered vector index.
pub struct IndexState {
    /// Vector dimension supplied when the state was created.
    dimension: usize,
    /// Quantization handed to `StoredVectorEntry::from_vector_entry` for new entries.
    quantization: VectorQuantization,
    /// Stored entries in push order; entries with a `deleted_tx` are excluded
    /// from the live counts.
    vectors: RwLock<Vec<StoredVectorEntry>>,
    /// Lazily initialized slot holding an optional HNSW graph.
    hnsw: OnceLock<RwLock<Option<HnswIndex>>>,
    /// Byte figure recorded for the HNSW graph; `clear_hnsw` releases it to
    /// the accountant and resets it to zero.
    hnsw_bytes: AtomicUsize,
}
17
18impl IndexState {
19    fn new(dimension: usize, quantization: VectorQuantization) -> Self {
20        Self {
21            dimension,
22            quantization,
23            vectors: RwLock::new(Vec::new()),
24            hnsw: OnceLock::new(),
25            hnsw_bytes: AtomicUsize::new(0),
26        }
27    }
28
29    pub fn dimension(&self) -> usize {
30        self.dimension
31    }
32
33    pub fn quantization(&self) -> VectorQuantization {
34        self.quantization
35    }
36
37    pub fn vector_count(&self) -> usize {
38        self.vectors
39            .read()
40            .iter()
41            .filter(|entry| entry.deleted_tx.is_none())
42            .count()
43    }
44
45    pub fn byte_count(&self) -> usize {
46        let payload_bytes = self
47            .vectors
48            .read()
49            .iter()
50            .filter(|entry| entry.deleted_tx.is_none())
51            .map(StoredVectorEntry::estimated_bytes)
52            .sum::<usize>();
53        payload_bytes.saturating_add(self.hnsw_bytes.load(Ordering::SeqCst))
54    }
55
56    pub fn all_entries(&self, index: &VectorIndexRef) -> Vec<VectorEntry> {
57        self.vectors
58            .read()
59            .iter()
60            .map(|entry| entry.to_vector_entry(index.clone()))
61            .collect()
62    }
63
64    pub fn find_by_row_id(&self, index: &VectorIndexRef, row_id: RowId) -> Option<VectorEntry> {
65        self.vectors
66            .read()
67            .iter()
68            .rev()
69            .find(|entry| entry.row_id == row_id)
70            .map(|entry| entry.to_vector_entry(index.clone()))
71    }
72
73    fn stored_by_row_id(&self, row_id: RowId) -> Option<StoredVectorEntry> {
74        self.vectors
75            .read()
76            .iter()
77            .rev()
78            .find(|entry| entry.row_id == row_id)
79            .cloned()
80    }
81
82    pub(crate) fn with_entries<R>(&self, f: impl FnOnce(&[StoredVectorEntry]) -> R) -> R {
83        let entries = self.vectors.read();
84        f(&entries)
85    }
86
87    pub fn entry_count(&self) -> usize {
88        self.vectors.read().len()
89    }
90
91    fn stored_entry(&self, entry: VectorEntry) -> StoredVectorEntry {
92        StoredVectorEntry::from_vector_entry(entry, self.quantization)
93    }
94
95    fn push_entry(&self, entry: StoredVectorEntry) {
96        self.vectors.write().push(entry);
97    }
98
99    fn tombstone_row(&self, row_id: RowId, deleted_tx: TxId) -> usize {
100        let mut released = 0usize;
101        let mut vectors = self.vectors.write();
102        for entry in vectors.iter_mut() {
103            if entry.row_id == row_id && entry.deleted_tx.is_none() {
104                released = released.saturating_add(entry.estimated_bytes());
105                entry.deleted_tx = Some(deleted_tx);
106            }
107        }
108        released
109    }
110
111    pub fn clear_hnsw(&self, accountant: &MemoryAccountant) {
112        let bytes = self.hnsw_bytes.swap(0, Ordering::SeqCst);
113        if bytes > 0 {
114            accountant.release(bytes);
115        }
116        if let Some(lock) = self.hnsw.get() {
117            *lock.write() = None;
118        }
119    }
120
121    pub fn hnsw_len(&self) -> Option<usize> {
122        self.hnsw
123            .get()
124            .and_then(|lock| lock.read().as_ref().map(|hnsw| hnsw.len()))
125    }
126
127    pub fn hnsw_stats(&self) -> Option<crate::HnswGraphStats> {
128        self.hnsw
129            .get()
130            .and_then(|lock| lock.read().as_ref().map(|hnsw| hnsw.graph_stats()))
131    }
132
133    pub fn set_hnsw(&self, hnsw: Option<HnswIndex>, bytes: usize) {
134        if hnsw.is_some() {
135            self.hnsw_bytes.store(bytes, Ordering::SeqCst);
136        }
137        let lock = self.hnsw.get_or_init(|| RwLock::new(None));
138        *lock.write() = hnsw;
139    }
140
141    pub fn set_hnsw_bytes(&self, bytes: usize) {
142        self.hnsw_bytes.store(bytes, Ordering::SeqCst);
143    }
144
145    pub fn hnsw(&self) -> &OnceLock<RwLock<Option<HnswIndex>>> {
146        &self.hnsw
147    }
148
149    pub fn storage_bytes_per_entry(&self) -> Vec<usize> {
150        self.vectors
151            .read()
152            .iter()
153            .map(StoredVectorEntry::estimated_bytes)
154            .collect()
155    }
156}
157
/// Summary of one registered index, as produced by `VectorStore::index_infos`.
pub struct VectorIndexInfo {
    /// The index this summary describes.
    pub index: VectorIndexRef,
    /// Value of `IndexState::dimension` for the index.
    pub dimension: usize,
    /// Value of `IndexState::quantization` for the index.
    pub quantization: VectorQuantization,
    /// Value of `IndexState::vector_count` (entries without a `deleted_tx`).
    pub vector_count: usize,
    /// Value of `IndexState::byte_count` (live payload plus recorded HNSW bytes).
    pub bytes: usize,
}
165
/// Registry of per-index vector state, keyed by `VectorIndexRef`.
pub struct VectorStore {
    /// Shared per-index state; lookups clone the `Arc` out of the map.
    registry: RwLock<HashMap<VectorIndexRef, Arc<IndexState>>>,
    /// Mutex handed out by `build_lock`.
    /// NOTE(review): presumably serializes HNSW builds — confirm against callers.
    build_mutex: Mutex<()>,
}
170
171impl Default for VectorStore {
172    fn default() -> Self {
173        Self::new(Arc::new(OnceLock::new()))
174    }
175}
176
177impl VectorStore {
178    pub fn new(_legacy_hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>) -> Self {
179        Self {
180            registry: RwLock::new(HashMap::new()),
181            build_mutex: Mutex::new(()),
182        }
183    }
184
185    pub fn register_index(
186        &self,
187        index: VectorIndexRef,
188        dimension: usize,
189        quantization: VectorQuantization,
190    ) {
191        let mut registry = self.registry.write();
192        registry
193            .entry(index)
194            .or_insert_with(|| Arc::new(IndexState::new(dimension, quantization)));
195    }
196
197    pub fn register_or_reconfigure_empty_index(
198        &self,
199        index: VectorIndexRef,
200        dimension: usize,
201        quantization: VectorQuantization,
202    ) {
203        let mut registry = self.registry.write();
204        match registry.get(&index) {
205            Some(state) if state.entry_count() != 0 => {}
206            Some(state) if state.dimension() == dimension => {}
207            Some(_) | None => {
208                registry.insert(index, Arc::new(IndexState::new(dimension, quantization)));
209            }
210        }
211    }
212
213    pub fn deregister_index(&self, index: &VectorIndexRef, accountant: &MemoryAccountant) {
214        if let Some(state) = self.registry.write().remove(index) {
215            state.clear_hnsw(accountant);
216        }
217    }
218
219    pub fn deregister_table(&self, table: &str, accountant: &MemoryAccountant) {
220        let removed = {
221            let mut registry = self.registry.write();
222            let keys = registry
223                .keys()
224                .filter(|index| index.table == table)
225                .cloned()
226                .collect::<Vec<_>>();
227            keys.into_iter()
228                .filter_map(|key| registry.remove(&key))
229                .collect::<Vec<_>>()
230        };
231        for state in removed {
232            state.clear_hnsw(accountant);
233        }
234    }
235
236    pub fn rename_index(&self, old: &VectorIndexRef, new: VectorIndexRef) -> Result<()> {
237        let mut registry = self.registry.write();
238        if registry.contains_key(&new) {
239            return Err(Error::Other(format!(
240                "vector index already exists: {}.{}",
241                new.table, new.column
242            )));
243        }
244        let state = registry
245            .remove(old)
246            .ok_or_else(|| Error::UnknownVectorIndex { index: old.clone() })?;
247        registry.insert(new, state);
248        Ok(())
249    }
250
251    pub fn state(&self, index: &VectorIndexRef) -> Result<Arc<IndexState>> {
252        self.registry
253            .read()
254            .get(index)
255            .cloned()
256            .ok_or_else(|| Error::UnknownVectorIndex {
257                index: index.clone(),
258            })
259    }
260
261    pub fn try_state(&self, index: &VectorIndexRef) -> Option<Arc<IndexState>> {
262        self.registry.read().get(index).cloned()
263    }
264
265    pub fn is_empty(&self) -> bool {
266        self.registry.read().is_empty()
267    }
268
269    pub fn index_count(&self) -> usize {
270        self.registry.read().len()
271    }
272
273    pub fn validate_vector(&self, index: &VectorIndexRef, actual: usize) -> Result<()> {
274        let state = self.state(index)?;
275        let expected = state.dimension();
276        if expected != actual {
277            return Err(Error::VectorIndexDimensionMismatch {
278                index: index.clone(),
279                expected,
280                actual,
281            });
282        }
283        Ok(())
284    }
285
286    pub fn apply_inserts(&self, inserts: Vec<VectorEntry>) {
287        for entry in inserts {
288            if let Some(state) = self.try_state(&entry.index) {
289                let stored_entry = state.stored_entry(entry);
290                let row_id = stored_entry.row_id;
291                state.push_entry(stored_entry.clone());
292                if let Some(lock) = state.hnsw().get() {
293                    let guard = lock.write();
294                    if let Some(hnsw) = guard.as_ref() {
295                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
296                            hnsw.insert(row_id, &stored_entry.vector);
297                        }));
298                    }
299                }
300            }
301        }
302    }
303
304    pub fn apply_deletes(&self, deletes: Vec<(VectorIndexRef, RowId, TxId)>) {
305        for (index, row_id, deleted_tx) in deletes {
306            if let Some(state) = self.try_state(&index) {
307                state.tombstone_row(row_id, deleted_tx);
308                if let Some(lock) = state.hnsw().get() {
309                    *lock.write() = None;
310                }
311            }
312        }
313    }
314
315    pub fn apply_moves(
316        &self,
317        moves: Vec<(VectorIndexRef, RowId, RowId, TxId)>,
318        lsn: contextdb_core::Lsn,
319    ) {
320        for (index, old_row_id, new_row_id, tx) in moves {
321            if let Some(state) = self.try_state(&index)
322                && let Some(old) = state.stored_by_row_id(old_row_id)
323                && old.deleted_tx.is_none()
324            {
325                state.tombstone_row(old_row_id, tx);
326                let mut moved = old;
327                moved.row_id = new_row_id;
328                moved.created_tx = tx;
329                moved.deleted_tx = None;
330                moved.lsn = lsn;
331                state.push_entry(moved.clone());
332                if let Some(lock) = state.hnsw().get() {
333                    let guard = lock.write();
334                    if let Some(hnsw) = guard.as_ref() {
335                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
336                            hnsw.insert(moved.row_id, &moved.vector);
337                        }));
338                    }
339                }
340            }
341        }
342    }
343
344    pub fn insert_loaded_vector(&self, entry: VectorEntry) {
345        let quantization = self
346            .try_state(&entry.index)
347            .map(|state| state.quantization())
348            .unwrap_or(VectorQuantization::F32);
349        self.register_or_reconfigure_empty_index(
350            entry.index.clone(),
351            entry.vector.len(),
352            quantization,
353        );
354        if let Some(state) = self.try_state(&entry.index) {
355            let stored_entry = state.stored_entry(entry);
356            state.push_entry(stored_entry);
357        }
358    }
359
360    pub fn all_entries(&self) -> Vec<VectorEntry> {
361        self.registry
362            .read()
363            .iter()
364            .flat_map(|(index, state)| state.all_entries(index))
365            .collect()
366    }
367
368    pub fn prune_row_ids(
369        &self,
370        row_ids: &std::collections::HashSet<RowId>,
371        accountant: &MemoryAccountant,
372    ) -> usize {
373        let mut released = 0usize;
374        for state in self.registry.read().values() {
375            let mut vectors = state.vectors.write();
376            vectors.retain(|entry| {
377                if row_ids.contains(&entry.row_id) {
378                    released = released.saturating_add(entry.estimated_bytes());
379                    false
380                } else {
381                    true
382                }
383            });
384            drop(vectors);
385            state.clear_hnsw(accountant);
386        }
387        released
388    }
389
390    pub fn entries_for_index(&self, index: &VectorIndexRef) -> Result<Vec<VectorEntry>> {
391        Ok(self.state(index)?.all_entries(index))
392    }
393
394    pub fn vector_count(&self) -> usize {
395        self.registry
396            .read()
397            .values()
398            .map(|state| state.vector_count())
399            .sum()
400    }
401
402    pub fn has_hnsw_index(&self) -> bool {
403        self.registry
404            .read()
405            .values()
406            .any(|state| state.hnsw_len().is_some())
407    }
408
409    pub fn has_hnsw_index_for(&self, index: &VectorIndexRef) -> bool {
410        self.try_state(index)
411            .and_then(|state| state.hnsw_len())
412            .is_some()
413    }
414
415    pub fn clear_hnsw(&self, accountant: &MemoryAccountant) {
416        for state in self.registry.read().values() {
417            state.clear_hnsw(accountant);
418        }
419    }
420
421    pub fn clear_hnsw_for(&self, index: &VectorIndexRef, accountant: &MemoryAccountant) {
422        if let Some(state) = self.try_state(index) {
423            state.clear_hnsw(accountant);
424        }
425    }
426
427    pub fn find_by_row_id(&self, row_id: RowId) -> Option<VectorEntry> {
428        self.registry
429            .read()
430            .iter()
431            .find_map(|(index, state)| state.find_by_row_id(index, row_id))
432    }
433
434    pub fn live_entry_for_row(
435        &self,
436        index: &VectorIndexRef,
437        row_id: RowId,
438        snapshot: contextdb_core::SnapshotId,
439    ) -> Option<VectorEntry> {
440        self.try_state(index).and_then(|state| {
441            state.with_entries(|entries| {
442                entries
443                    .iter()
444                    .rev()
445                    .find(|entry| entry.row_id == row_id && entry.visible_at(snapshot))
446                    .map(|entry| entry.to_vector_entry(index.clone()))
447            })
448        })
449    }
450
451    pub fn live_entries_for_row(
452        &self,
453        row_id: RowId,
454        snapshot: contextdb_core::SnapshotId,
455    ) -> Vec<VectorEntry> {
456        self.registry
457            .read()
458            .iter()
459            .flat_map(|(index, state)| state.all_entries(index))
460            .filter(|entry| entry.row_id == row_id && entry.visible_at(snapshot))
461            .collect()
462    }
463
464    pub fn vector_for_row_lsn(
465        &self,
466        index: &VectorIndexRef,
467        row_id: RowId,
468        lsn: contextdb_core::Lsn,
469    ) -> Option<Vec<f32>> {
470        self.try_state(index).and_then(|state| {
471            state.with_entries(|entries| {
472                entries
473                    .iter()
474                    .find(|entry| entry.row_id == row_id && entry.lsn == lsn)
475                    .map(|entry| entry.vector.to_f32())
476            })
477        })
478    }
479
480    pub fn storage_bytes_per_entry(&self, index: &VectorIndexRef) -> Result<Vec<usize>> {
481        Ok(self.state(index)?.storage_bytes_per_entry())
482    }
483
484    pub fn index_infos(&self) -> Vec<VectorIndexInfo> {
485        let mut infos = self
486            .registry
487            .read()
488            .iter()
489            .map(|(index, state)| VectorIndexInfo {
490                index: index.clone(),
491                dimension: state.dimension(),
492                quantization: state.quantization(),
493                vector_count: state.vector_count(),
494                bytes: state.byte_count(),
495            })
496            .collect::<Vec<_>>();
497        infos.sort_by(|a, b| {
498            a.index
499                .table
500                .cmp(&b.index.table)
501                .then(a.index.column.cmp(&b.index.column))
502        });
503        infos
504    }
505
506    pub fn build_lock(&self) -> parking_lot::MutexGuard<'_, ()> {
507        self.build_mutex.lock()
508    }
509}