use crate::{HnswIndex, quantized::StoredVectorEntry};
use contextdb_core::{
    Error, MemoryAccountant, Result, RowId, TxId, VectorEntry, VectorIndexRef, VectorQuantization,
};
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};

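/// In-memory state for a single vector index: the configured dimension and
/// quantization, the stored entries (tombstoned on delete and only physically
/// removed when pruned), an optionally built HNSW graph, and the byte count
/// attributed to that graph.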
pub struct IndexState {
    dimension: usize,
    quantization: VectorQuantization,
    vectors: RwLock<Vec<StoredVectorEntry>>,
    hnsw: OnceLock<RwLock<Option<HnswIndex>>>,
    hnsw_bytes: AtomicUsize,
}

impl IndexState {
    fn new(dimension: usize, quantization: VectorQuantization) -> Self {
        Self {
            dimension,
            quantization,
            vectors: RwLock::new(Vec::new()),
            hnsw: OnceLock::new(),
            hnsw_bytes: AtomicUsize::new(0),
        }
    }

    pub fn dimension(&self) -> usize {
        self.dimension
    }

    pub fn quantization(&self) -> VectorQuantization {
        self.quantization
    }

    pub fn vector_count(&self) -> usize {
        self.vectors
            .read()
            .iter()
            .filter(|entry| entry.deleted_tx.is_none())
            .count()
    }

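    /// Estimated bytes held by live (non-tombstoned) entries plus the bytes
    /// currently attributed to the HNSW graph.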
    pub fn byte_count(&self) -> usize {
        let payload_bytes = self
            .vectors
            .read()
            .iter()
            .filter(|entry| entry.deleted_tx.is_none())
            .map(StoredVectorEntry::estimated_bytes)
            .sum::<usize>();
        payload_bytes.saturating_add(self.hnsw_bytes.load(Ordering::SeqCst))
    }

    pub fn all_entries(&self, index: &VectorIndexRef) -> Vec<VectorEntry> {
        self.vectors
            .read()
            .iter()
            .map(|entry| entry.to_vector_entry(index.clone()))
            .collect()
    }

    pub fn find_by_row_id(&self, index: &VectorIndexRef, row_id: RowId) -> Option<VectorEntry> {
        self.vectors
            .read()
            .iter()
            .rev()
            .find(|entry| entry.row_id == row_id)
            .map(|entry| entry.to_vector_entry(index.clone()))
    }

    fn stored_by_row_id(&self, row_id: RowId) -> Option<StoredVectorEntry> {
        self.vectors
            .read()
            .iter()
            .rev()
            .find(|entry| entry.row_id == row_id)
            .cloned()
    }

    pub(crate) fn with_entries<R>(&self, f: impl FnOnce(&[StoredVectorEntry]) -> R) -> R {
        let entries = self.vectors.read();
        f(&entries)
    }

    pub fn entry_count(&self) -> usize {
        self.vectors.read().len()
    }

    fn stored_entry(&self, entry: VectorEntry) -> StoredVectorEntry {
        StoredVectorEntry::from_vector_entry(entry, self.quantization)
    }

    fn push_entry(&self, entry: StoredVectorEntry) {
        self.vectors.write().push(entry);
    }

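    /// Marks every live entry for `row_id` as deleted at `deleted_tx` and
    /// returns the estimated bytes released by doing so.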
    fn tombstone_row(&self, row_id: RowId, deleted_tx: TxId) -> usize {
        let mut released = 0usize;
        let mut vectors = self.vectors.write();
        for entry in vectors.iter_mut() {
            if entry.row_id == row_id && entry.deleted_tx.is_none() {
                released = released.saturating_add(entry.estimated_bytes());
                entry.deleted_tx = Some(deleted_tx);
            }
        }
        released
    }

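    /// Drops the HNSW graph (if one was built) and returns its tracked bytes
    /// to the memory accountant.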
    pub fn clear_hnsw(&self, accountant: &MemoryAccountant) {
        let bytes = self.hnsw_bytes.swap(0, Ordering::SeqCst);
        if bytes > 0 {
            accountant.release(bytes);
        }
        if let Some(lock) = self.hnsw.get() {
            *lock.write() = None;
        }
    }

    pub fn hnsw_len(&self) -> Option<usize> {
        self.hnsw
            .get()
            .and_then(|lock| lock.read().as_ref().map(|hnsw| hnsw.len()))
    }

    pub fn hnsw_stats(&self) -> Option<crate::HnswGraphStats> {
        self.hnsw
            .get()
            .and_then(|lock| lock.read().as_ref().map(|hnsw| hnsw.graph_stats()))
    }

    pub fn set_hnsw(&self, hnsw: Option<HnswIndex>, bytes: usize) {
        if hnsw.is_some() {
            self.hnsw_bytes.store(bytes, Ordering::SeqCst);
        }
        let lock = self.hnsw.get_or_init(|| RwLock::new(None));
        *lock.write() = hnsw;
    }

    pub fn set_hnsw_bytes(&self, bytes: usize) {
        self.hnsw_bytes.store(bytes, Ordering::SeqCst);
    }

    pub fn hnsw(&self) -> &OnceLock<RwLock<Option<HnswIndex>>> {
        &self.hnsw
    }

    pub fn storage_bytes_per_entry(&self) -> Vec<usize> {
        self.vectors
            .read()
            .iter()
            .map(StoredVectorEntry::estimated_bytes)
            .collect()
    }
}

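/// Summary of a registered vector index, as reported by `VectorStore::index_infos`.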
pub struct VectorIndexInfo {
    pub index: VectorIndexRef,
    pub dimension: usize,
    pub quantization: VectorQuantization,
    pub vector_count: usize,
    pub bytes: usize,
}

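/// Registry of per-index [`IndexState`] keyed by `VectorIndexRef`, plus a
/// mutex, exposed through `build_lock`, for serializing index builds.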
pub struct VectorStore {
    registry: RwLock<HashMap<VectorIndexRef, Arc<IndexState>>>,
    build_mutex: Mutex<()>,
}

impl Default for VectorStore {
    fn default() -> Self {
        Self::new(Arc::new(OnceLock::new()))
    }
}

impl VectorStore {
    pub fn new(_legacy_hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>) -> Self {
        Self {
            registry: RwLock::new(HashMap::new()),
            build_mutex: Mutex::new(()),
        }
    }

    pub fn register_index(
        &self,
        index: VectorIndexRef,
        dimension: usize,
        quantization: VectorQuantization,
    ) {
        let mut registry = self.registry.write();
        registry
            .entry(index)
            .or_insert_with(|| Arc::new(IndexState::new(dimension, quantization)));
    }

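    /// Registers `index` if it is unknown, and replaces an existing registration
    /// only when it holds no entries and its dimension differs; a populated index,
    /// or one that already matches the requested dimension, is left untouched.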
    pub fn register_or_reconfigure_empty_index(
        &self,
        index: VectorIndexRef,
        dimension: usize,
        quantization: VectorQuantization,
    ) {
        let mut registry = self.registry.write();
        match registry.get(&index) {
            Some(state) if state.entry_count() != 0 => {}
            Some(state) if state.dimension() == dimension => {}
            Some(_) | None => {
                registry.insert(index, Arc::new(IndexState::new(dimension, quantization)));
            }
        }
    }

    pub fn deregister_index(&self, index: &VectorIndexRef, accountant: &MemoryAccountant) {
        if let Some(state) = self.registry.write().remove(index) {
            state.clear_hnsw(accountant);
        }
    }

    pub fn deregister_table(&self, table: &str, accountant: &MemoryAccountant) {
        let removed = {
            let mut registry = self.registry.write();
            let keys = registry
                .keys()
                .filter(|index| index.table == table)
                .cloned()
                .collect::<Vec<_>>();
            keys.into_iter()
                .filter_map(|key| registry.remove(&key))
                .collect::<Vec<_>>()
        };
        for state in removed {
            state.clear_hnsw(accountant);
        }
    }

    pub fn rename_index(&self, old: &VectorIndexRef, new: VectorIndexRef) -> Result<()> {
        let mut registry = self.registry.write();
        if registry.contains_key(&new) {
            return Err(Error::Other(format!(
                "vector index already exists: {}.{}",
                new.table, new.column
            )));
        }
        let state = registry
            .remove(old)
            .ok_or_else(|| Error::UnknownVectorIndex { index: old.clone() })?;
        registry.insert(new, state);
        Ok(())
    }

    pub fn state(&self, index: &VectorIndexRef) -> Result<Arc<IndexState>> {
        self.registry
            .read()
            .get(index)
            .cloned()
            .ok_or_else(|| Error::UnknownVectorIndex {
                index: index.clone(),
            })
    }

    pub fn try_state(&self, index: &VectorIndexRef) -> Option<Arc<IndexState>> {
        self.registry.read().get(index).cloned()
    }

    pub fn is_empty(&self) -> bool {
        self.registry.read().is_empty()
    }

    pub fn index_count(&self) -> usize {
        self.registry.read().len()
    }

    pub fn validate_vector(&self, index: &VectorIndexRef, actual: usize) -> Result<()> {
        let state = self.state(index)?;
        let expected = state.dimension();
        if expected != actual {
            return Err(Error::VectorIndexDimensionMismatch {
                index: index.clone(),
                expected,
                actual,
            });
        }
        Ok(())
    }

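    /// Appends each entry to its index (quantizing it on the way in) and, when an
    /// HNSW graph already exists for that index, inserts the new vector into the
    /// graph as well, swallowing any panic raised by the graph insert.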
    pub fn apply_inserts(&self, inserts: Vec<VectorEntry>) {
        for entry in inserts {
            if let Some(state) = self.try_state(&entry.index) {
                let stored_entry = state.stored_entry(entry);
                let row_id = stored_entry.row_id;
                state.push_entry(stored_entry.clone());
                if let Some(lock) = state.hnsw().get() {
                    let guard = lock.write();
                    if let Some(hnsw) = guard.as_ref() {
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                            hnsw.insert(row_id, &stored_entry.vector);
                        }));
                    }
                }
            }
        }
    }

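    /// Tombstones the affected rows and drops any existing HNSW graph for the
    /// touched indexes, since the graph would otherwise still contain them.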
    pub fn apply_deletes(&self, deletes: Vec<(VectorIndexRef, RowId, TxId)>) {
        for (index, row_id, deleted_tx) in deletes {
            if let Some(state) = self.try_state(&index) {
                state.tombstone_row(row_id, deleted_tx);
                if let Some(lock) = state.hnsw().get() {
                    *lock.write() = None;
                }
            }
        }
    }

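    /// Re-keys live entries from `old_row_id` to `new_row_id`: the old entry is
    /// tombstoned and a copy is appended under the new row id with the given
    /// transaction and LSN, feeding the copy into an existing HNSW graph as well.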
    pub fn apply_moves(
        &self,
        moves: Vec<(VectorIndexRef, RowId, RowId, TxId)>,
        lsn: contextdb_core::Lsn,
    ) {
        for (index, old_row_id, new_row_id, tx) in moves {
            if let Some(state) = self.try_state(&index)
                && let Some(old) = state.stored_by_row_id(old_row_id)
                && old.deleted_tx.is_none()
            {
                state.tombstone_row(old_row_id, tx);
                let mut moved = old;
                moved.row_id = new_row_id;
                moved.created_tx = tx;
                moved.deleted_tx = None;
                moved.lsn = lsn;
                state.push_entry(moved.clone());
                if let Some(lock) = state.hnsw().get() {
                    let guard = lock.write();
                    if let Some(hnsw) = guard.as_ref() {
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                            hnsw.insert(moved.row_id, &moved.vector);
                        }));
                    }
                }
            }
        }
    }

    pub fn insert_loaded_vector(&self, entry: VectorEntry) {
        let quantization = self
            .try_state(&entry.index)
            .map(|state| state.quantization())
            .unwrap_or(VectorQuantization::F32);
        self.register_or_reconfigure_empty_index(
            entry.index.clone(),
            entry.vector.len(),
            quantization,
        );
        if let Some(state) = self.try_state(&entry.index) {
            let stored_entry = state.stored_entry(entry);
            state.push_entry(stored_entry);
        }
    }

    pub fn all_entries(&self) -> Vec<VectorEntry> {
        self.registry
            .read()
            .iter()
            .flat_map(|(index, state)| state.all_entries(index))
            .collect()
    }

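    /// Physically removes entries matching the given row ids from every index,
    /// drops all HNSW graphs (their bytes go back to the accountant), and returns
    /// the estimated bytes released by the removed entries.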
    pub fn prune_row_ids(
        &self,
        row_ids: &std::collections::HashSet<RowId>,
        accountant: &MemoryAccountant,
    ) -> usize {
        let mut released = 0usize;
        for state in self.registry.read().values() {
            let mut vectors = state.vectors.write();
            vectors.retain(|entry| {
                if row_ids.contains(&entry.row_id) {
                    released = released.saturating_add(entry.estimated_bytes());
                    false
                } else {
                    true
                }
            });
            drop(vectors);
            state.clear_hnsw(accountant);
        }
        released
    }

    pub fn entries_for_index(&self, index: &VectorIndexRef) -> Result<Vec<VectorEntry>> {
        Ok(self.state(index)?.all_entries(index))
    }

    pub fn vector_count(&self) -> usize {
        self.registry
            .read()
            .values()
            .map(|state| state.vector_count())
            .sum()
    }

    pub fn has_hnsw_index(&self) -> bool {
        self.registry
            .read()
            .values()
            .any(|state| state.hnsw_len().is_some())
    }

    pub fn has_hnsw_index_for(&self, index: &VectorIndexRef) -> bool {
        self.try_state(index)
            .and_then(|state| state.hnsw_len())
            .is_some()
    }

    pub fn clear_hnsw(&self, accountant: &MemoryAccountant) {
        for state in self.registry.read().values() {
            state.clear_hnsw(accountant);
        }
    }

    pub fn clear_hnsw_for(&self, index: &VectorIndexRef, accountant: &MemoryAccountant) {
        if let Some(state) = self.try_state(index) {
            state.clear_hnsw(accountant);
        }
    }

    pub fn find_by_row_id(&self, row_id: RowId) -> Option<VectorEntry> {
        self.registry
            .read()
            .iter()
            .find_map(|(index, state)| state.find_by_row_id(index, row_id))
    }

    pub fn live_entry_for_row(
        &self,
        index: &VectorIndexRef,
        row_id: RowId,
        snapshot: contextdb_core::SnapshotId,
    ) -> Option<VectorEntry> {
        self.try_state(index).and_then(|state| {
            state.with_entries(|entries| {
                entries
                    .iter()
                    .rev()
                    .find(|entry| entry.row_id == row_id && entry.visible_at(snapshot))
                    .map(|entry| entry.to_vector_entry(index.clone()))
            })
        })
    }

    pub fn live_entries_for_row(
        &self,
        row_id: RowId,
        snapshot: contextdb_core::SnapshotId,
    ) -> Vec<VectorEntry> {
        self.registry
            .read()
            .iter()
            .flat_map(|(index, state)| state.all_entries(index))
            .filter(|entry| entry.row_id == row_id && entry.visible_at(snapshot))
            .collect()
    }

    pub fn vector_for_row_lsn(
        &self,
        index: &VectorIndexRef,
        row_id: RowId,
        lsn: contextdb_core::Lsn,
    ) -> Option<Vec<f32>> {
        self.try_state(index).and_then(|state| {
            state.with_entries(|entries| {
                entries
                    .iter()
                    .find(|entry| entry.row_id == row_id && entry.lsn == lsn)
                    .map(|entry| entry.vector.to_f32())
            })
        })
    }

    pub fn storage_bytes_per_entry(&self, index: &VectorIndexRef) -> Result<Vec<usize>> {
        Ok(self.state(index)?.storage_bytes_per_entry())
    }

    pub fn index_infos(&self) -> Vec<VectorIndexInfo> {
        let mut infos = self
            .registry
            .read()
            .iter()
            .map(|(index, state)| VectorIndexInfo {
                index: index.clone(),
                dimension: state.dimension(),
                quantization: state.quantization(),
                vector_count: state.vector_count(),
                bytes: state.byte_count(),
            })
            .collect::<Vec<_>>();
        infos.sort_by(|a, b| {
            a.index
                .table
                .cmp(&b.index.table)
                .then(a.index.column.cmp(&b.index.column))
        });
        infos
    }

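    /// Acquires the store-wide build mutex; the guard is meant to be held while
    /// (re)building HNSW graphs so that only one build runs at a time.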
    pub fn build_lock(&self) -> parking_lot::MutexGuard<'_, ()> {
        self.build_mutex.lock()
    }
}