1use arroy::distances::DotProduct;
4use heed::{types::*, RwTxn};
5use std::fmt::Debug;
6use std::sync::atomic::AtomicUsize;
7
8use arroy::{Database as ArroyDatabase, Reader, Writer};
9use heed::types::SerdeJson;
10use heed::{Database, EnvOpenOptions};
11use kalosm_language_model::*;
12use rand::rngs::StdRng;
13use rand::SeedableRng;
14use serde::{Deserialize, Serialize};
15
16pub type Candidates = roaring::RoaringBitmap;
18
19#[derive(Debug, thiserror::Error)]
21pub enum VectorDbError {
22 #[error("Arroy error: {0}")]
24 Arroy(#[from] arroy::Error),
25 #[error("Embedding {0:?} not found")]
27 EmbeddingNotFound(EmbeddingId),
28}
29
30impl From<heed::Error> for VectorDbError {
31 fn from(value: heed::Error) -> Self {
32 Self::Arroy(value.into())
33 }
34}
35
36#[doc(alias = "VectorDatabase")]
80#[doc(alias = "Vector Database")]
81pub struct VectorDB {
82 database: ArroyDatabase<DotProduct>,
83 metadata: Database<Str, SerdeJson<Vec<u32>>>,
84 env: heed::Env,
85 dim: AtomicUsize,
86}
87
88impl Default for VectorDB {
89 fn default() -> Self {
90 Self::new().unwrap()
91 }
92}
93
94impl VectorDB {
95 fn set_dim(&self, dim: usize) {
96 if dim == 0 {
97 panic!("Dimension cannot be 0");
98 }
99 self.dim.store(dim, std::sync::atomic::Ordering::Relaxed);
100 }
101
102 fn get_dim(&self) -> Result<usize, arroy::Error> {
103 let mut dims = self.dim.load(std::sync::atomic::Ordering::Relaxed);
104 if dims == 0 {
105 let rtxn = self.env.read_txn()?;
106 let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
107 dims = reader.dimensions();
108 self.set_dim(dims);
109 }
110 Ok(dims)
111 }
112
113 #[tracing::instrument]
115 pub fn new() -> heed::Result<Self> {
116 let dir = tempfile::tempdir()?;
117
118 Self::new_at(dir.path())
119 }
120
121 pub fn new_at(path: impl AsRef<std::path::Path>) -> heed::Result<Self> {
123 const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024;
124
125 std::fs::create_dir_all(&path)?;
126
127 let env = unsafe {
128 EnvOpenOptions::new()
129 .map_size(TWENTY_HUNDRED_MIB)
130 .open(path)
131 }?;
132
133 let mut wtxn = env.write_txn()?;
134 let db: ArroyDatabase<DotProduct> = env.create_database(&mut wtxn, None)?;
135 let metadata: Database<Str, SerdeJson<Vec<u32>>> = env.create_database(&mut wtxn, None)?;
136 wtxn.commit()?;
137
138 Ok(Self {
139 database: db,
140 metadata,
141 env,
142 dim: AtomicUsize::new(0),
143 })
144 }
145
146 fn take_id(&self, wtxn: &mut RwTxn) -> Result<EmbeddingId, heed::Error> {
147 if let Some(mut free) = self.metadata.get(wtxn, "free")? {
148 if let Some(id) = free.pop() {
149 self.metadata.put(wtxn, "free", &free)?;
150 return Ok(EmbeddingId(id));
151 }
152 }
153 match self.metadata.get(wtxn, "max")? {
154 Some(max) => {
155 let id = max[0];
156 self.metadata.put(wtxn, "max", &vec![id + 1])?;
157 Ok(EmbeddingId(id))
158 }
159 None => {
160 self.metadata.put(wtxn, "max", &vec![1])?;
161 Ok(EmbeddingId(0))
162 }
163 }
164 }
165
166 fn recycle_id(&self, id: EmbeddingId, wtxn: &mut RwTxn) -> Result<(), heed::Error> {
167 let mut free = self.metadata.get(wtxn, "free")?.unwrap_or_default();
168 free.push(id.0);
169 self.metadata.put(wtxn, "free", &free)?;
170
171 Ok(())
172 }
173
174 pub fn raw(&self) -> (&ArroyDatabase<DotProduct>, &heed::Env) {
176 (&self.database, &self.env)
177 }
178
179 pub async fn clear(&self) -> Result<(), arroy::Error> {
181 let mut wtxn = self.env.write_txn()?;
182 let dims = self.get_dim()?;
183 let writer = Writer::<DotProduct>::new(self.database, 0, dims);
184 writer.clear(&mut wtxn)?;
185
186 self.metadata.put(&mut wtxn, "max", &vec![0])?;
188 self.metadata.put(&mut wtxn, "free", &vec![])?;
189 wtxn.commit()?;
190
191 Ok(())
192 }
193
194 pub fn rebuild(
196 &self,
197 writer: &mut Writer<DotProduct>,
198 wtxn: &mut RwTxn,
199 ) -> Result<(), arroy::Error> {
200 let mut rng = StdRng::from_entropy();
201 writer.builder(&mut rng).build(wtxn)?;
202
203 Ok(())
204 }
205
206 pub fn remove_embedding(&self, embedding_id: EmbeddingId) -> Result<(), arroy::Error> {
208 let dims = self.get_dim()?;
209
210 let mut wtxn = self.env.write_txn()?;
211
212 let mut writer = Writer::<DotProduct>::new(self.database, 0, dims);
213
214 writer.del_item(&mut wtxn, embedding_id.0)?;
215 self.recycle_id(embedding_id, &mut wtxn)?;
216
217 self.rebuild(&mut writer, &mut wtxn)?;
218
219 wtxn.commit()?;
220
221 Ok(())
222 }
223
224 pub fn add_embedding(&self, embedding: Embedding) -> Result<EmbeddingId, VectorDbError> {
228 let embedding = embedding.vector();
229
230 self.set_dim(embedding.len());
231
232 let mut wtxn = self.env.write_txn()?;
233
234 let mut writer = Writer::<DotProduct>::new(self.database, 0, embedding.len());
235
236 let id = self.take_id(&mut wtxn)?;
237
238 writer.add_item(&mut wtxn, id.0, embedding)?;
239
240 self.rebuild(&mut writer, &mut wtxn)?;
241
242 wtxn.commit()?;
243
244 Ok(id)
245 }
246
247 pub fn add_embeddings(
249 &self,
250 embedding: impl IntoIterator<Item = Embedding>,
251 ) -> Result<Vec<EmbeddingId>, VectorDbError> {
252 let mut embeddings = embedding
253 .into_iter()
254 .map(|e| e.vector().to_vec().into_boxed_slice());
255 let Some(first_embedding) = embeddings.next() else {
256 return Ok(Vec::new());
257 };
258 self.set_dim(first_embedding.len());
259
260 let mut wtxn = self.env.write_txn()?;
261 let mut writer = Writer::<DotProduct>::new(self.database, 0, first_embedding.len());
262
263 let mut ids: Vec<_> = Vec::with_capacity(embeddings.size_hint().0 + 1);
264
265 {
266 let first_id = self.take_id(&mut wtxn)?;
267 writer.add_item(&mut wtxn, first_id.0, &first_embedding)?;
268 ids.push(first_id);
269 }
270
271 for embedding in embeddings {
272 let id = self.take_id(&mut wtxn)?;
273 writer.add_item(&mut wtxn, id.0, &embedding)?;
274 ids.push(id);
275 }
276
277 self.rebuild(&mut writer, &mut wtxn)?;
278
279 wtxn.commit()?;
280
281 Ok(ids)
282 }
283
284 pub fn get_embedding(&self, embedding_id: EmbeddingId) -> Result<Embedding, VectorDbError> {
286 let rtxn = self.env.read_txn()?;
287 let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
288
289 let embedding = reader
290 .item_vector(&rtxn, embedding_id.0)?
291 .ok_or_else(|| VectorDbError::EmbeddingNotFound(embedding_id))?;
292
293 Ok(Embedding::from(embedding))
294 }
295
296 pub fn search<'a>(&'a self, embedding: &'a Embedding) -> VectorDBSearchBuilder<'a> {
298 VectorDBSearchBuilder {
299 db: self,
300 embedding,
301 results: None,
302 filter: None,
303 }
304 }
305}
306
307pub trait IntoVectorDbSearchFilter<M> {
309 fn into_vector_db_search_filter(self, db: &VectorDB) -> Candidates;
311}
312
313impl IntoVectorDbSearchFilter<()> for Candidates {
314 fn into_vector_db_search_filter(self, _: &VectorDB) -> Candidates {
315 self
316 }
317}
318
/// Marker type selecting the [`IntoVectorDbSearchFilter`] implementation for
/// iterators of [`EmbeddingId`]s.
pub struct IteratorMarker;
321
322impl<I> IntoVectorDbSearchFilter<IteratorMarker> for I
323where
324 I: IntoIterator<Item = EmbeddingId>,
325{
326 fn into_vector_db_search_filter(self, _: &VectorDB) -> Candidates {
327 let mut candidates = Candidates::new();
328 for id in self {
329 candidates.insert(id.0);
330 }
331 candidates
332 }
333}
334
/// Marker type selecting the [`IntoVectorDbSearchFilter`] implementation for
/// closures that accept or reject an [`Embedding`].
pub struct ClosureMarker;
337
338impl<I> IntoVectorDbSearchFilter<ClosureMarker> for I
339where
340 I: FnMut(Embedding) -> bool,
341{
342 fn into_vector_db_search_filter(mut self, db: &VectorDB) -> Candidates {
343 let mut candidates = Candidates::new();
344 let rtxn = match db.env.read_txn() {
345 Ok(rtxn) => rtxn,
346 Err(err) => {
347 tracing::error!("Error opening read transaction: {:?}", err);
348 return candidates;
349 }
350 };
351 let reader = match Reader::<DotProduct>::open(&rtxn, 0, db.database) {
352 Ok(reader) => reader,
353 Err(err) => {
354 tracing::error!("Error opening reader: {:?}", err);
355 return candidates;
356 }
357 };
358 for (key, tensor) in reader.iter(&rtxn).ok().into_iter().flatten().flatten() {
359 let embedding = Embedding::from(tensor);
360 if self(embedding) {
361 candidates.insert(key);
362 }
363 }
364 candidates
365 }
366}
367
368pub struct VectorDBSearchBuilder<'a> {
370 db: &'a VectorDB,
371 embedding: &'a Embedding,
372 results: Option<usize>,
373 filter: Option<Candidates>,
374}
375
376impl VectorDBSearchBuilder<'_> {
377 pub fn with_results(mut self, results: usize) -> Self {
379 self.results = Some(results);
380 self
381 }
382
383 pub fn with_filter<Marker>(
385 mut self,
386 filter: impl IntoVectorDbSearchFilter<Marker> + Send + Sync + 'static,
387 ) -> Self {
388 self.filter = Some(filter.into_vector_db_search_filter(self.db));
389 self
390 }
391
392 pub fn run(self) -> Result<Vec<VectorDBSearchResult>, VectorDbError> {
394 let rtxn = self.db.env.read_txn()?;
395 let reader = Reader::<DotProduct>::open(&rtxn, 0, self.db.database)?;
396
397 let vector = self.embedding.vector();
398 let mut query = reader.nns(self.results.unwrap_or(10));
399 if let Some(filter) = self.filter.as_ref() {
400 query.candidates(filter);
401 }
402 let arroy_results = query.by_vector(&rtxn, vector)?;
403
404 Ok(arroy_results
405 .into_iter()
406 .map(|(id, distance)| {
407 let value = EmbeddingId(id);
408 VectorDBSearchResult { distance, value }
409 })
410 .collect::<Vec<_>>())
411 }
412}
413
414#[derive(Debug, Clone, PartialEq)]
416pub struct VectorDBSearchResult {
417 pub distance: f32,
419 pub value: EmbeddingId,
421}
422
423#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
425pub struct EmbeddingId(pub u32);
426
427#[tokio::test]
428async fn test_vector_db_get_closest() {
429 let db: VectorDB = VectorDB::new().unwrap();
430 let first_vector = Embedding::from([1.0, 2.0, 3.0]);
431 let second_embedding = Embedding::from([-1.0, 2.0, 3.0]);
432 let id1 = db.add_embedding(first_vector.clone()).unwrap();
433 let id2 = db.add_embedding(second_embedding.clone()).unwrap();
434 assert_eq!(
435 db.search(&first_vector)
436 .with_results(1)
437 .run()
438 .unwrap()
439 .iter()
440 .map(|r| r.value)
441 .collect::<Vec<_>>(),
442 vec![id1]
443 );
444 assert_eq!(
445 db.search(&second_embedding)
446 .with_results(1)
447 .run()
448 .unwrap()
449 .iter()
450 .map(|r| r.value)
451 .collect::<Vec<_>>(),
452 vec![id2]
453 );
454 let third_embedding = Embedding::from([1.0, 0.0, 0.0]);
455 assert_eq!(
456 db.search(&third_embedding)
457 .with_results(1)
458 .run()
459 .unwrap()
460 .iter()
461 .map(|r| r.value)
462 .collect::<Vec<_>>(),
463 vec![id1]
464 );
465 assert_eq!(
466 db.search(&third_embedding)
467 .with_filter(|vector: Embedding| vector.vector()[0] < 0.0)
468 .run()
469 .unwrap()
470 .iter()
471 .map(|r| r.value)
472 .collect::<Vec<_>>(),
473 vec![id2]
474 );
475}