Skip to main content

contextdb_vector/
mem.rs

1use crate::{HnswIndex, store::VectorStore};
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use parking_lot::RwLock;
5use roaring::RoaringTreemap;
6use std::collections::HashSet;
7use std::sync::{Arc, OnceLock};
8
/// Entry count at or above which `search` tries the HNSW path before
/// falling back to a brute-force scan.
const HNSW_THRESHOLD: usize = 1_000;
10
11pub struct MemVectorExecutor<S: WriteSetApplicator> {
12    store: Arc<VectorStore>,
13    tx_mgr: Arc<TxManager<S>>,
14    accountant: Arc<MemoryAccountant>,
15}
16
17impl<S: WriteSetApplicator> MemVectorExecutor<S> {
18    pub fn new(
19        store: Arc<VectorStore>,
20        tx_mgr: Arc<TxManager<S>>,
21        hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
22    ) -> Self {
23        Self::new_with_accountant(store, tx_mgr, hnsw, Arc::new(MemoryAccountant::no_limit()))
24    }
25
26    pub fn new_with_accountant(
27        store: Arc<VectorStore>,
28        tx_mgr: Arc<TxManager<S>>,
29        _hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
30        accountant: Arc<MemoryAccountant>,
31    ) -> Self {
32        Self {
33            store,
34            tx_mgr,
35            accountant,
36        }
37    }
38
39    fn brute_force_search(
40        &self,
41        index: &VectorIndexRef,
42        query: &[f32],
43        k: usize,
44        candidates: Option<&RoaringTreemap>,
45        snapshot: SnapshotId,
46    ) -> Result<Vec<(RowId, f32)>> {
47        let state = self.store.state(index)?;
48        if query.len() != state.dimension() {
49            return Err(Error::VectorIndexDimensionMismatch {
50                index: index.clone(),
51                expected: state.dimension(),
52                actual: query.len(),
53            });
54        }
55
56        let mut scored: Vec<(RowId, f32)> = state.with_entries(|entries| {
57            let mut scored = Vec::new();
58            for entry in entries {
59                if !entry.visible_at(snapshot) {
60                    continue;
61                }
62
63                if let Some(cands) = candidates
64                    && !cands.contains(entry.row_id.0)
65                {
66                    continue;
67                }
68
69                let sim = entry.vector.cosine_similarity(query);
70                scored.push((entry.row_id, sim));
71            }
72            scored
73        });
74
75        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
76        scored.truncate(k);
77        Ok(scored)
78    }
79
80    fn build_hnsw_from_store(&self, index: &VectorIndexRef) -> Option<HnswIndex> {
81        let _build_guard = self.store.build_lock();
82        let state = self.store.try_state(index)?;
83        let dim = state.dimension();
84        let entry_count = state.entry_count();
85        let estimated_bytes = estimate_hnsw_bytes(entry_count, dim, state.quantization());
86        if self
87            .accountant
88            .try_allocate_for(
89                estimated_bytes,
90                "vector_index",
91                &format!("build_hnsw@{}.{}", index.table, index.column),
92                "Reduce vector volume or raise MEMORY_LIMIT so the HNSW index can be built.",
93            )
94            .is_err()
95        {
96            return None;
97        }
98
99        let built = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
100            state.with_entries(|entries| HnswIndex::new(entries, dim, state.quantization()))
101        }))
102        .ok();
103        if built.is_none() {
104            self.accountant.release(estimated_bytes);
105        } else {
106            state.set_hnsw_bytes(estimated_bytes);
107        }
108        built
109    }
110}
111
112impl<S: WriteSetApplicator> VectorExecutor for MemVectorExecutor<S> {
113    fn search(
114        &self,
115        index: VectorIndexRef,
116        query: &[f32],
117        k: usize,
118        candidates: Option<&RoaringTreemap>,
119        snapshot: SnapshotId,
120    ) -> Result<Vec<(RowId, f32)>> {
121        if k == 0 {
122            return Ok(Vec::new());
123        }
124        let Some(state) = self.store.try_state(&index) else {
125            return Err(Error::UnknownVectorIndex { index });
126        };
127        if query.len() != state.dimension() {
128            return Err(Error::VectorIndexDimensionMismatch {
129                index,
130                expected: state.dimension(),
131                actual: query.len(),
132            });
133        }
134        if state.entry_count() == 0 {
135            return Ok(Vec::new());
136        }
137
138        let use_hnsw = state.entry_count() >= HNSW_THRESHOLD;
139        if use_hnsw {
140            let lock = state
141                .hnsw()
142                .get_or_init(|| RwLock::new(self.build_hnsw_from_store(&index)));
143
144            {
145                let mut guard = lock.write();
146                if guard.is_none() {
147                    *guard = self.build_hnsw_from_store(&index);
148                }
149            }
150
151            let guard = lock.read();
152            if let Some(hnsw) = guard.as_ref() {
153                let raw_candidates = hnsw.search(&index, query, k)?;
154                let raw_candidate_count = raw_candidates.len();
155
156                if candidates.is_some() && raw_candidates.len() < hnsw.len() {
157                    return self.brute_force_search(&index, query, k, candidates, snapshot);
158                }
159
160                let supplement_missing = raw_candidate_count.saturating_add(64) >= hnsw.len();
161                let raw_row_ids = if supplement_missing {
162                    raw_candidates
163                        .iter()
164                        .map(|(row_id, _)| *row_id)
165                        .collect::<HashSet<_>>()
166                } else {
167                    HashSet::new()
168                };
169                let mut visible = state.with_entries(|entries| {
170                    let mut visible = raw_candidates
171                        .into_iter()
172                        .filter_map(|(rid, _)| {
173                            entries
174                                .iter()
175                                .find(|entry| entry.row_id == rid && entry.visible_at(snapshot))
176                                .and_then(|entry| {
177                                    if let Some(cands) = candidates
178                                        && !cands.contains(entry.row_id.0)
179                                    {
180                                        return None;
181                                    }
182                                    Some((entry.row_id, entry.vector.cosine_similarity(query)))
183                                })
184                        })
185                        .collect::<Vec<_>>();
186                    if supplement_missing {
187                        for entry in entries {
188                            if raw_row_ids.contains(&entry.row_id) || !entry.visible_at(snapshot) {
189                                continue;
190                            }
191                            visible.push((entry.row_id, entry.vector.cosine_similarity(query)));
192                        }
193                    }
194                    visible
195                });
196
197                visible.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198                if visible.len() < k && raw_candidate_count < hnsw.len() {
199                    return self.brute_force_search(&index, query, k, candidates, snapshot);
200                }
201                visible.truncate(k);
202                return Ok(visible);
203            }
204        }
205
206        self.brute_force_search(&index, query, k, candidates, snapshot)
207    }
208
209    fn insert_vector(
210        &self,
211        tx: TxId,
212        index: VectorIndexRef,
213        row_id: RowId,
214        vector: Vec<f32>,
215    ) -> Result<()> {
216        self.store.validate_vector(&index, vector.len())?;
217        let entry = VectorEntry {
218            index,
219            row_id,
220            vector,
221            created_tx: tx,
222            deleted_tx: None,
223            lsn: contextdb_core::Lsn(0),
224        };
225
226        self.tx_mgr.with_write_set(tx, |ws| {
227            ws.vector_inserts.push(entry);
228        })?;
229
230        Ok(())
231    }
232
233    fn delete_vector(&self, tx: TxId, index: VectorIndexRef, row_id: RowId) -> Result<()> {
234        self.store.state(&index)?;
235        self.tx_mgr.with_write_set(tx, |ws| {
236            ws.vector_deletes.push((index, row_id, tx));
237        })?;
238
239        Ok(())
240    }
241}
242
243fn estimate_hnsw_bytes(
244    entry_count: usize,
245    dimension: usize,
246    quantization: VectorQuantization,
247) -> usize {
248    let entry_bytes = match quantization {
249        VectorQuantization::F32 => quantization.storage_bytes(dimension),
250        VectorQuantization::SQ8 => dimension.saturating_add(12),
251        VectorQuantization::SQ4 => dimension.div_ceil(2).saturating_add(12),
252    };
253    entry_count.saturating_mul(entry_bytes).saturating_mul(3)
254}