1use crate::{HnswIndex, cosine::cosine_similarity, store::VectorStore};
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use parking_lot::RwLock;
5use roaring::RoaringTreemap;
6use std::sync::{Arc, OnceLock};
7
/// Minimum number of stored vectors before an HNSW index is built and used;
/// below this, every search falls back to a brute-force scan.
const HNSW_THRESHOLD: usize = 1000;
9
/// In-memory `VectorExecutor` backed by a shared `VectorStore`, with a
/// lazily-built HNSW index for collections past `HNSW_THRESHOLD`.
pub struct MemVectorExecutor<S: WriteSetApplicator> {
    // Shared vector storage; holds the entries and the dimension metadata.
    store: Arc<VectorStore>,
    // Transaction manager; inserts/deletes are staged into per-tx write sets.
    tx_mgr: Arc<TxManager<S>>,
    // Lazily-initialized HNSW index. `None` inside the lock means the build
    // has not happened yet or failed, and may be retried on a later search.
    hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
    // Memory budget that HNSW index builds are charged against.
    accountant: Arc<MemoryAccountant>,
}
16
17impl<S: WriteSetApplicator> MemVectorExecutor<S> {
18 pub fn new(
19 store: Arc<VectorStore>,
20 tx_mgr: Arc<TxManager<S>>,
21 hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
22 ) -> Self {
23 Self::new_with_accountant(store, tx_mgr, hnsw, Arc::new(MemoryAccountant::no_limit()))
24 }
25
26 pub fn new_with_accountant(
27 store: Arc<VectorStore>,
28 tx_mgr: Arc<TxManager<S>>,
29 hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
30 accountant: Arc<MemoryAccountant>,
31 ) -> Self {
32 Self {
33 store,
34 tx_mgr,
35 hnsw,
36 accountant,
37 }
38 }
39
40 fn brute_force_search(
41 &self,
42 query: &[f32],
43 k: usize,
44 candidates: Option<&RoaringTreemap>,
45 snapshot: SnapshotId,
46 ) -> Vec<(RowId, f32)> {
47 let vectors = self.store.vectors.read();
48 let mut scored: Vec<(RowId, f32)> = Vec::new();
49
50 for entry in vectors.iter() {
51 if !entry.visible_at(snapshot) {
52 continue;
53 }
54
55 if let Some(cands) = candidates
56 && !cands.contains(entry.row_id.0)
57 {
58 continue;
59 }
60
61 let sim = cosine_similarity(query, &entry.vector);
62 scored.push((entry.row_id, sim));
63 }
64
65 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
66 scored.truncate(k);
67 scored
68 }
69
70 fn build_hnsw_from_store(&self) -> Option<HnswIndex> {
71 let entries = self.store.all_entries();
72 let dim = self.store.dimension().unwrap_or(0);
73 let estimated_bytes = estimate_hnsw_bytes(entries.len(), dim);
74 if self
75 .accountant
76 .try_allocate_for(
77 estimated_bytes,
78 "vector_index",
79 "build_hnsw",
80 "Reduce vector volume or raise MEMORY_LIMIT so the HNSW index can be built.",
81 )
82 .is_err()
83 {
84 return None;
85 }
86
87 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
88 HnswIndex::new(&entries, dim)
89 }))
90 .ok()
91 .inspect(|_| self.store.set_hnsw_bytes(estimated_bytes))
92 .or_else(|| {
93 self.accountant.release(estimated_bytes);
94 None
95 })
96 }
97}
98
99impl<S: WriteSetApplicator> VectorExecutor for MemVectorExecutor<S> {
100 fn search(
101 &self,
102 query: &[f32],
103 k: usize,
104 candidates: Option<&RoaringTreemap>,
105 snapshot: SnapshotId,
106 ) -> Result<Vec<(RowId, f32)>> {
107 if k == 0 {
108 return Ok(Vec::new());
109 }
110
111 let use_hnsw = self.store.vector_count() >= HNSW_THRESHOLD;
112 if use_hnsw {
113 let once_lock = self
114 .hnsw
115 .get_or_init(|| RwLock::new(self.build_hnsw_from_store()));
116
117 {
118 let mut guard = once_lock.write();
119 if guard.is_none() {
120 *guard = self.build_hnsw_from_store();
121 }
122 }
123
124 let guard = once_lock.read();
125 if let Some(hnsw) = guard.as_ref() {
126 let raw_candidates = hnsw.search(query, k)?;
127
128 if raw_candidates.len() < hnsw.len() {
132 return Ok(self.brute_force_search(query, k, candidates, snapshot));
133 }
134
135 let vectors = self.store.vectors.read();
136 let mut visible = raw_candidates
137 .into_iter()
138 .filter_map(|(rid, _)| {
139 vectors
140 .iter()
141 .find(|entry| entry.row_id == rid && entry.visible_at(snapshot))
142 .map(|entry| {
143 if let Some(cands) = candidates
144 && !cands.contains(entry.row_id.0)
145 {
146 return None;
147 }
148
149 Some((entry.row_id, cosine_similarity(query, &entry.vector)))
150 })
151 .unwrap_or(None)
152 })
153 .collect::<Vec<_>>();
154
155 visible.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
156 visible.truncate(k);
157 return Ok(visible);
158 }
159 }
160
161 Ok(self.brute_force_search(query, k, candidates, snapshot))
162 }
163
164 fn insert_vector(&self, tx: TxId, row_id: RowId, vector: Vec<f32>) -> Result<()> {
165 let got = vector.len();
166
167 {
168 let mut dim = self.store.dimension.write();
169 match *dim {
170 None => *dim = Some(got),
171 Some(expected) if expected != got => {
172 return Err(Error::VectorDimensionMismatch { expected, got });
173 }
174 _ => {}
175 }
176 }
177
178 let entry = VectorEntry {
179 row_id,
180 vector,
181 created_tx: tx,
182 deleted_tx: None,
183 lsn: contextdb_core::Lsn(0),
184 };
185
186 self.tx_mgr.with_write_set(tx, |ws| {
187 ws.vector_inserts.push(entry);
188 })?;
189
190 Ok(())
191 }
192
193 fn delete_vector(&self, tx: TxId, row_id: RowId) -> Result<()> {
194 self.tx_mgr.with_write_set(tx, |ws| {
195 ws.vector_deletes.push((row_id, tx));
196 })?;
197
198 Ok(())
199 }
200}
201
/// Rough upper bound on the memory an HNSW index over `entry_count`
/// vectors of `dimension` f32 components will need. Saturates instead of
/// overflowing on pathological inputs.
fn estimate_hnsw_bytes(entry_count: usize, dimension: usize) -> usize {
    // Raw vector payload: count * dim * 4 bytes per f32 component.
    let payload = entry_count
        .saturating_mul(dimension)
        .saturating_mul(std::mem::size_of::<f32>());
    // Graph links and build scratch roughly double the payload: budget 3x.
    payload.saturating_mul(3)
}