1use crate::{HnswIndex, store::VectorStore};
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use parking_lot::RwLock;
5use roaring::RoaringTreemap;
6use std::collections::HashSet;
7use std::sync::{Arc, OnceLock};
8
/// Minimum entry count at which `search` switches from brute force to HNSW.
const HNSW_THRESHOLD: usize = 1000;
10
/// In-memory vector executor: exact and HNSW-accelerated similarity search
/// over a shared [`VectorStore`], with writes staged through a [`TxManager`].
pub struct MemVectorExecutor<S: WriteSetApplicator> {
    // Shared per-index vector storage (entries plus lazily built HNSW state).
    store: Arc<VectorStore>,
    // Transaction manager; inserts/deletes are staged in its write sets.
    tx_mgr: Arc<TxManager<S>>,
    // Memory budget charged when building HNSW graphs; `no_limit()` by default.
    accountant: Arc<MemoryAccountant>,
}
16
impl<S: WriteSetApplicator> MemVectorExecutor<S> {
    /// Builds an executor with an unbounded memory accountant.
    pub fn new(
        store: Arc<VectorStore>,
        tx_mgr: Arc<TxManager<S>>,
        hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
    ) -> Self {
        Self::new_with_accountant(store, tx_mgr, hnsw, Arc::new(MemoryAccountant::no_limit()))
    }

    /// Builds an executor with an explicit memory accountant.
    ///
    /// NOTE(review): `_hnsw` is unused — the per-index HNSW cell is read from
    /// `VectorState::hnsw()` in `search` instead. The parameter is presumably
    /// kept for signature compatibility with callers; confirm before removing.
    pub fn new_with_accountant(
        store: Arc<VectorStore>,
        tx_mgr: Arc<TxManager<S>>,
        _hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
        accountant: Arc<MemoryAccountant>,
    ) -> Self {
        Self {
            store,
            tx_mgr,
            accountant,
        }
    }

    /// Exact k-NN: scores every entry visible at `snapshot` (optionally
    /// restricted to `candidates`) by cosine similarity against `query`.
    ///
    /// Returns the top-`k` `(row_id, similarity)` pairs, best first.
    ///
    /// # Errors
    /// Propagates the store's lookup error for an unknown index, and returns
    /// `VectorIndexDimensionMismatch` when `query.len()` differs from the
    /// index dimension.
    fn brute_force_search(
        &self,
        index: &VectorIndexRef,
        query: &[f32],
        k: usize,
        candidates: Option<&RoaringTreemap>,
        snapshot: SnapshotId,
    ) -> Result<Vec<(RowId, f32)>> {
        let state = self.store.state(index)?;
        if query.len() != state.dimension() {
            return Err(Error::VectorIndexDimensionMismatch {
                index: index.clone(),
                expected: state.dimension(),
                actual: query.len(),
            });
        }

        let mut scored: Vec<(RowId, f32)> = state.with_entries(|entries| {
            let mut scored = Vec::new();
            for entry in entries {
                // Skip versions not visible under this snapshot.
                if !entry.visible_at(snapshot) {
                    continue;
                }

                // Honor the optional row-id candidate filter.
                if let Some(cands) = candidates
                    && !cands.contains(entry.row_id.0)
                {
                    continue;
                }

                let sim = entry.vector.cosine_similarity(query);
                scored.push((entry.row_id, sim));
            }
            scored
        });

        // Descending similarity; incomparable (NaN) pairs sort as Equal.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        Ok(scored)
    }

    /// Builds an HNSW graph for `index` from the current store contents,
    /// reserving the estimated size with the memory accountant first.
    ///
    /// Returns `None` when the index is unknown, when the memory budget would
    /// be exceeded, or when the build itself panics; in the panic case the
    /// reservation is released. On success the estimate is recorded on the
    /// index state via `set_hnsw_bytes`.
    fn build_hnsw_from_store(&self, index: &VectorIndexRef) -> Option<HnswIndex> {
        // Serialize builds through the store's build lock; the guard is held
        // for the entire build (reservation + construction).
        let _build_guard = self.store.build_lock();
        let state = self.store.try_state(index)?;
        let dim = state.dimension();
        let entry_count = state.entry_count();
        let estimated_bytes = estimate_hnsw_bytes(entry_count, dim, state.quantization());
        // Reserve up front: refuse to build rather than exceed the limit.
        // Nothing to release on this early return — the reservation failed.
        if self
            .accountant
            .try_allocate_for(
                estimated_bytes,
                "vector_index",
                &format!("build_hnsw@{}.{}", index.table, index.column),
                "Reduce vector volume or raise MEMORY_LIMIT so the HNSW index can be built.",
            )
            .is_err()
        {
            return None;
        }

        // Contain a panicking build so callers see `None` instead of an
        // unwind; AssertUnwindSafe is needed because the closure borrows state.
        let built = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            state.with_entries(|entries| HnswIndex::new(entries, dim, state.quantization()))
        }))
        .ok();
        if built.is_none() {
            // Build panicked: return the reserved bytes.
            self.accountant.release(estimated_bytes);
        } else {
            // Track the graph's footprint on the index state.
            state.set_hnsw_bytes(estimated_bytes);
        }
        built
    }
}
111
112impl<S: WriteSetApplicator> VectorExecutor for MemVectorExecutor<S> {
113 fn search(
114 &self,
115 index: VectorIndexRef,
116 query: &[f32],
117 k: usize,
118 candidates: Option<&RoaringTreemap>,
119 snapshot: SnapshotId,
120 ) -> Result<Vec<(RowId, f32)>> {
121 if k == 0 {
122 return Ok(Vec::new());
123 }
124 let Some(state) = self.store.try_state(&index) else {
125 return Err(Error::UnknownVectorIndex { index });
126 };
127 if query.len() != state.dimension() {
128 return Err(Error::VectorIndexDimensionMismatch {
129 index,
130 expected: state.dimension(),
131 actual: query.len(),
132 });
133 }
134 if state.entry_count() == 0 {
135 return Ok(Vec::new());
136 }
137
138 let use_hnsw = state.entry_count() >= HNSW_THRESHOLD;
139 if use_hnsw {
140 let lock = state
141 .hnsw()
142 .get_or_init(|| RwLock::new(self.build_hnsw_from_store(&index)));
143
144 {
145 let mut guard = lock.write();
146 if guard.is_none() {
147 *guard = self.build_hnsw_from_store(&index);
148 }
149 }
150
151 let guard = lock.read();
152 if let Some(hnsw) = guard.as_ref() {
153 let raw_candidates = hnsw.search(&index, query, k)?;
154 let raw_candidate_count = raw_candidates.len();
155
156 if candidates.is_some() && raw_candidates.len() < hnsw.len() {
157 return self.brute_force_search(&index, query, k, candidates, snapshot);
158 }
159
160 let supplement_missing = raw_candidate_count.saturating_add(64) >= hnsw.len();
161 let raw_row_ids = if supplement_missing {
162 raw_candidates
163 .iter()
164 .map(|(row_id, _)| *row_id)
165 .collect::<HashSet<_>>()
166 } else {
167 HashSet::new()
168 };
169 let mut visible = state.with_entries(|entries| {
170 let mut visible = raw_candidates
171 .into_iter()
172 .filter_map(|(rid, _)| {
173 entries
174 .iter()
175 .find(|entry| entry.row_id == rid && entry.visible_at(snapshot))
176 .and_then(|entry| {
177 if let Some(cands) = candidates
178 && !cands.contains(entry.row_id.0)
179 {
180 return None;
181 }
182 Some((entry.row_id, entry.vector.cosine_similarity(query)))
183 })
184 })
185 .collect::<Vec<_>>();
186 if supplement_missing {
187 for entry in entries {
188 if raw_row_ids.contains(&entry.row_id) || !entry.visible_at(snapshot) {
189 continue;
190 }
191 visible.push((entry.row_id, entry.vector.cosine_similarity(query)));
192 }
193 }
194 visible
195 });
196
197 visible.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198 if visible.len() < k && raw_candidate_count < hnsw.len() {
199 return self.brute_force_search(&index, query, k, candidates, snapshot);
200 }
201 visible.truncate(k);
202 return Ok(visible);
203 }
204 }
205
206 self.brute_force_search(&index, query, k, candidates, snapshot)
207 }
208
209 fn insert_vector(
210 &self,
211 tx: TxId,
212 index: VectorIndexRef,
213 row_id: RowId,
214 vector: Vec<f32>,
215 ) -> Result<()> {
216 self.store.validate_vector(&index, vector.len())?;
217 let entry = VectorEntry {
218 index,
219 row_id,
220 vector,
221 created_tx: tx,
222 deleted_tx: None,
223 lsn: contextdb_core::Lsn(0),
224 };
225
226 self.tx_mgr.with_write_set(tx, |ws| {
227 ws.vector_inserts.push(entry);
228 })?;
229
230 Ok(())
231 }
232
233 fn delete_vector(&self, tx: TxId, index: VectorIndexRef, row_id: RowId) -> Result<()> {
234 self.store.state(&index)?;
235 self.tx_mgr.with_write_set(tx, |ws| {
236 ws.vector_deletes.push((index, row_id, tx));
237 })?;
238
239 Ok(())
240 }
241}
242
243fn estimate_hnsw_bytes(
244 entry_count: usize,
245 dimension: usize,
246 quantization: VectorQuantization,
247) -> usize {
248 let entry_bytes = match quantization {
249 VectorQuantization::F32 => quantization.storage_bytes(dimension),
250 VectorQuantization::SQ8 => dimension.saturating_add(12),
251 VectorQuantization::SQ4 => dimension.div_ceil(2).saturating_add(12),
252 };
253 entry_count.saturating_mul(entry_bytes).saturating_mul(3)
254}