Skip to main content

contextdb_vector/
mem.rs

1use crate::{HnswIndex, cosine::cosine_similarity, store::VectorStore};
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use parking_lot::RwLock;
5use roaring::RoaringTreemap;
6use std::sync::{Arc, OnceLock};
7
/// Minimum number of stored vectors at which `search` consults the HNSW index
/// instead of scanning every vector.
const HNSW_THRESHOLD: usize = 1_000;
9
10pub struct MemVectorExecutor<S: WriteSetApplicator> {
11    store: Arc<VectorStore>,
12    tx_mgr: Arc<TxManager<S>>,
13    hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
14    accountant: Arc<MemoryAccountant>,
15}
16
17impl<S: WriteSetApplicator> MemVectorExecutor<S> {
18    pub fn new(
19        store: Arc<VectorStore>,
20        tx_mgr: Arc<TxManager<S>>,
21        hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
22    ) -> Self {
23        Self::new_with_accountant(store, tx_mgr, hnsw, Arc::new(MemoryAccountant::no_limit()))
24    }
25
26    pub fn new_with_accountant(
27        store: Arc<VectorStore>,
28        tx_mgr: Arc<TxManager<S>>,
29        hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
30        accountant: Arc<MemoryAccountant>,
31    ) -> Self {
32        Self {
33            store,
34            tx_mgr,
35            hnsw,
36            accountant,
37        }
38    }
39
40    fn brute_force_search(
41        &self,
42        query: &[f32],
43        k: usize,
44        candidates: Option<&RoaringTreemap>,
45        snapshot: SnapshotId,
46    ) -> Vec<(RowId, f32)> {
47        let vectors = self.store.vectors.read();
48        let mut scored: Vec<(RowId, f32)> = Vec::new();
49
50        for entry in vectors.iter() {
51            if !entry.visible_at(snapshot) {
52                continue;
53            }
54
55            if let Some(cands) = candidates
56                && !cands.contains(entry.row_id.0)
57            {
58                continue;
59            }
60
61            let sim = cosine_similarity(query, &entry.vector);
62            scored.push((entry.row_id, sim));
63        }
64
65        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
66        scored.truncate(k);
67        scored
68    }
69
70    fn build_hnsw_from_store(&self) -> Option<HnswIndex> {
71        let entries = self.store.all_entries();
72        let dim = self.store.dimension().unwrap_or(0);
73        let estimated_bytes = estimate_hnsw_bytes(entries.len(), dim);
74        if self
75            .accountant
76            .try_allocate_for(
77                estimated_bytes,
78                "vector_index",
79                "build_hnsw",
80                "Reduce vector volume or raise MEMORY_LIMIT so the HNSW index can be built.",
81            )
82            .is_err()
83        {
84            return None;
85        }
86
87        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
88            HnswIndex::new(&entries, dim)
89        }))
90        .ok()
91        .inspect(|_| self.store.set_hnsw_bytes(estimated_bytes))
92        .or_else(|| {
93            self.accountant.release(estimated_bytes);
94            None
95        })
96    }
97}
98
99impl<S: WriteSetApplicator> VectorExecutor for MemVectorExecutor<S> {
100    fn search(
101        &self,
102        query: &[f32],
103        k: usize,
104        candidates: Option<&RoaringTreemap>,
105        snapshot: SnapshotId,
106    ) -> Result<Vec<(RowId, f32)>> {
107        if k == 0 {
108            return Ok(Vec::new());
109        }
110
111        let use_hnsw = self.store.vector_count() >= HNSW_THRESHOLD;
112        if use_hnsw {
113            let once_lock = self
114                .hnsw
115                .get_or_init(|| RwLock::new(self.build_hnsw_from_store()));
116
117            {
118                let mut guard = once_lock.write();
119                if guard.is_none() {
120                    *guard = self.build_hnsw_from_store();
121                }
122            }
123
124            let guard = once_lock.read();
125            if let Some(hnsw) = guard.as_ref() {
126                let raw_candidates = hnsw.search(query, k)?;
127
128                // If the HNSW graph has disconnected components, the search
129                // may not reach all indexed vectors. Detect this and fall back
130                // to brute-force so we never silently drop results.
131                if raw_candidates.len() < hnsw.len() {
132                    return Ok(self.brute_force_search(query, k, candidates, snapshot));
133                }
134
135                let vectors = self.store.vectors.read();
136                let mut visible = raw_candidates
137                    .into_iter()
138                    .filter_map(|(rid, _)| {
139                        vectors
140                            .iter()
141                            .find(|entry| entry.row_id == rid && entry.visible_at(snapshot))
142                            .map(|entry| {
143                                if let Some(cands) = candidates
144                                    && !cands.contains(entry.row_id.0)
145                                {
146                                    return None;
147                                }
148
149                                Some((entry.row_id, cosine_similarity(query, &entry.vector)))
150                            })
151                            .unwrap_or(None)
152                    })
153                    .collect::<Vec<_>>();
154
155                visible.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
156                visible.truncate(k);
157                return Ok(visible);
158            }
159        }
160
161        Ok(self.brute_force_search(query, k, candidates, snapshot))
162    }
163
164    fn insert_vector(&self, tx: TxId, row_id: RowId, vector: Vec<f32>) -> Result<()> {
165        let got = vector.len();
166
167        {
168            let mut dim = self.store.dimension.write();
169            match *dim {
170                None => *dim = Some(got),
171                Some(expected) if expected != got => {
172                    return Err(Error::VectorDimensionMismatch { expected, got });
173                }
174                _ => {}
175            }
176        }
177
178        let entry = VectorEntry {
179            row_id,
180            vector,
181            created_tx: tx,
182            deleted_tx: None,
183            lsn: contextdb_core::Lsn(0),
184        };
185
186        self.tx_mgr.with_write_set(tx, |ws| {
187            ws.vector_inserts.push(entry);
188        })?;
189
190        Ok(())
191    }
192
193    fn delete_vector(&self, tx: TxId, row_id: RowId) -> Result<()> {
194        self.tx_mgr.with_write_set(tx, |ws| {
195            ws.vector_deletes.push((row_id, tx));
196        })?;
197
198        Ok(())
199    }
200}
201
/// Rough estimate, in bytes, of the memory needed for an HNSW index over
/// `entry_count` vectors of `dimension` `f32` components.
///
/// The raw vector payload is multiplied by a fixed overhead factor of 3. Every
/// multiplication saturates, so oversized inputs yield `usize::MAX` instead of wrapping.
fn estimate_hnsw_bytes(entry_count: usize, dimension: usize) -> usize {
    const OVERHEAD_FACTOR: usize = 3;
    let components = entry_count.saturating_mul(dimension);
    let payload_bytes = components.saturating_mul(std::mem::size_of::<f32>());
    payload_bytes.saturating_mul(OVERHEAD_FACTOR)
}