kalosm_language/
vector_db.rs

//! A vector database that can be used to store embeddings and search for similar embeddings.
2
3use arroy::distances::DotProduct;
4use heed::{types::*, RwTxn};
5use std::fmt::Debug;
6use std::sync::atomic::AtomicUsize;
7
8use arroy::{Database as ArroyDatabase, Reader, Writer};
9use heed::types::SerdeJson;
10use heed::{Database, EnvOpenOptions};
11use kalosm_language_model::*;
12use rand::rngs::StdRng;
13use rand::SeedableRng;
14use serde::{Deserialize, Serialize};
15
16/// A set of candidates for a vector search.
17pub type Candidates = roaring::RoaringBitmap;
18
19/// An error that can occur when adding or searching for an embedding to the vector database.
20#[derive(Debug, thiserror::Error)]
21pub enum VectorDbError {
22    /// An error from the arroy crate.
23    #[error("Arroy error: {0}")]
24    Arroy(#[from] arroy::Error),
25    /// An error from querying an embedding id that does not exist.
26    #[error("Embedding {0:?} not found")]
27    EmbeddingNotFound(EmbeddingId),
28}
29
30impl From<heed::Error> for VectorDbError {
31    fn from(value: heed::Error) -> Self {
32        Self::Arroy(value.into())
33    }
34}
35
36/// A vector database that can be used to store embeddings and search for similar embeddings.
37///
38/// It uses an in memory database with fast lookups for nearest neighbors and points within a certain distance.
39///
40/// # Example
41///
42/// ```rust, no_run
43/// # use kalosm_language::prelude::*;
44/// # use kalosm_language_model::*;
45/// # use rbert::*;
46/// # use std::collections::HashMap;
47/// # #[tokio::main]
48/// # async fn main() {
49/// // Create a good default Bert model for search
50/// let bert = Bert::new_for_search().await.unwrap();
51/// let sentences = [
52///     "Kalosm can be used to build local AI applications",
53///     "With private LLMs data never leaves your computer",
54///     "The quick brown fox jumps over the lazy dog",
55/// ];
56/// // Embed sentences into the vector space
57/// let embeddings = bert.embed_batch(sentences).await.unwrap();
58/// println!("embeddings {:?}", embeddings);
59///
60/// // Create a vector database from the embeddings along with a map between the embedding ids and the sentences
61/// let db = VectorDB::new().unwrap();
62/// let embeddings = db.add_embeddings(embeddings).unwrap();
63/// let embedding_id_to_sentence: HashMap<EmbeddingId, &str> =
64///     HashMap::from_iter(embeddings.into_iter().zip(sentences));
65///
66/// // Find the closest sentence to "What is Kalosm?"
67/// let query = "What is Kalosm?";
68/// // Embed the query into the vector space. We use `embed_query` instead of `embed` because some models embed queries differently than normal text.
69/// let embedding = bert.embed_query(query).await.unwrap();
70/// let closest = db.search(&embedding).run().unwrap();
71/// if let [closest] = closest.as_slice() {
72///     let distance = closest.distance;
73///     let text = embedding_id_to_sentence.get(&closest.value).unwrap();
74///     println!("distance: {distance}");
75///     println!("closest:  {text}");
76/// }
77/// # }
78/// ```
79#[doc(alias = "VectorDatabase")]
80#[doc(alias = "Vector Database")]
81pub struct VectorDB {
82    database: ArroyDatabase<DotProduct>,
83    metadata: Database<Str, SerdeJson<Vec<u32>>>,
84    env: heed::Env,
85    dim: AtomicUsize,
86}
87
88impl Default for VectorDB {
89    fn default() -> Self {
90        Self::new().unwrap()
91    }
92}
93
94impl VectorDB {
95    fn set_dim(&self, dim: usize) {
96        if dim == 0 {
97            panic!("Dimension cannot be 0");
98        }
99        self.dim.store(dim, std::sync::atomic::Ordering::Relaxed);
100    }
101
102    fn get_dim(&self) -> Result<usize, arroy::Error> {
103        let mut dims = self.dim.load(std::sync::atomic::Ordering::Relaxed);
104        if dims == 0 {
105            let rtxn = self.env.read_txn()?;
106            let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
107            dims = reader.dimensions();
108            self.set_dim(dims);
109        }
110        Ok(dims)
111    }
112
113    /// Create a new temporary vector database.
114    #[tracing::instrument]
115    pub fn new() -> heed::Result<Self> {
116        let dir = tempfile::tempdir()?;
117
118        Self::new_at(dir.path())
119    }
120
121    /// Create a new vector database at the given path.
122    pub fn new_at(path: impl AsRef<std::path::Path>) -> heed::Result<Self> {
123        const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024;
124
125        std::fs::create_dir_all(&path)?;
126
127        let env = unsafe {
128            EnvOpenOptions::new()
129                .map_size(TWENTY_HUNDRED_MIB)
130                .open(path)
131        }?;
132
133        let mut wtxn = env.write_txn()?;
134        let db: ArroyDatabase<DotProduct> = env.create_database(&mut wtxn, None)?;
135        let metadata: Database<Str, SerdeJson<Vec<u32>>> = env.create_database(&mut wtxn, None)?;
136        wtxn.commit()?;
137
138        Ok(Self {
139            database: db,
140            metadata,
141            env,
142            dim: AtomicUsize::new(0),
143        })
144    }
145
146    fn take_id(&self, wtxn: &mut RwTxn) -> Result<EmbeddingId, heed::Error> {
147        if let Some(mut free) = self.metadata.get(wtxn, "free")? {
148            if let Some(id) = free.pop() {
149                self.metadata.put(wtxn, "free", &free)?;
150                return Ok(EmbeddingId(id));
151            }
152        }
153        match self.metadata.get(wtxn, "max")? {
154            Some(max) => {
155                let id = max[0];
156                self.metadata.put(wtxn, "max", &vec![id + 1])?;
157                Ok(EmbeddingId(id))
158            }
159            None => {
160                self.metadata.put(wtxn, "max", &vec![1])?;
161                Ok(EmbeddingId(0))
162            }
163        }
164    }
165
166    fn recycle_id(&self, id: EmbeddingId, wtxn: &mut RwTxn) -> Result<(), heed::Error> {
167        let mut free = self.metadata.get(wtxn, "free")?.unwrap_or_default();
168        free.push(id.0);
169        self.metadata.put(wtxn, "free", &free)?;
170
171        Ok(())
172    }
173
174    /// Get the underlying database.
175    pub fn raw(&self) -> (&ArroyDatabase<DotProduct>, &heed::Env) {
176        (&self.database, &self.env)
177    }
178
179    /// Clear the vector database.
180    pub async fn clear(&self) -> Result<(), arroy::Error> {
181        let mut wtxn = self.env.write_txn()?;
182        let dims = self.get_dim()?;
183        let writer = Writer::<DotProduct>::new(self.database, 0, dims);
184        writer.clear(&mut wtxn)?;
185
186        // Reset the ids
187        self.metadata.put(&mut wtxn, "max", &vec![0])?;
188        self.metadata.put(&mut wtxn, "free", &vec![])?;
189        wtxn.commit()?;
190
191        Ok(())
192    }
193
194    /// Rebuild the database.
195    pub fn rebuild(
196        &self,
197        writer: &mut Writer<DotProduct>,
198        wtxn: &mut RwTxn,
199    ) -> Result<(), arroy::Error> {
200        let mut rng = StdRng::from_entropy();
201        writer.builder(&mut rng).build(wtxn)?;
202
203        Ok(())
204    }
205
206    /// Remove an embedding from the vector database.
207    pub fn remove_embedding(&self, embedding_id: EmbeddingId) -> Result<(), arroy::Error> {
208        let dims = self.get_dim()?;
209
210        let mut wtxn = self.env.write_txn()?;
211
212        let mut writer = Writer::<DotProduct>::new(self.database, 0, dims);
213
214        writer.del_item(&mut wtxn, embedding_id.0)?;
215        self.recycle_id(embedding_id, &mut wtxn)?;
216
217        self.rebuild(&mut writer, &mut wtxn)?;
218
219        wtxn.commit()?;
220
221        Ok(())
222    }
223
224    /// Add a new embedding to the vector database.
225    ///
226    /// Note: Adding embeddings in a batch with [`VectorDB::add_embeddings`] will be faster.
227    pub fn add_embedding(&self, embedding: Embedding) -> Result<EmbeddingId, VectorDbError> {
228        let embedding = embedding.vector();
229
230        self.set_dim(embedding.len());
231
232        let mut wtxn = self.env.write_txn()?;
233
234        let mut writer = Writer::<DotProduct>::new(self.database, 0, embedding.len());
235
236        let id = self.take_id(&mut wtxn)?;
237
238        writer.add_item(&mut wtxn, id.0, embedding)?;
239
240        self.rebuild(&mut writer, &mut wtxn)?;
241
242        wtxn.commit()?;
243
244        Ok(id)
245    }
246
247    /// Add a new batch of embeddings to the vector database.
248    pub fn add_embeddings(
249        &self,
250        embedding: impl IntoIterator<Item = Embedding>,
251    ) -> Result<Vec<EmbeddingId>, VectorDbError> {
252        let mut embeddings = embedding
253            .into_iter()
254            .map(|e| e.vector().to_vec().into_boxed_slice());
255        let Some(first_embedding) = embeddings.next() else {
256            return Ok(Vec::new());
257        };
258        self.set_dim(first_embedding.len());
259
260        let mut wtxn = self.env.write_txn()?;
261        let mut writer = Writer::<DotProduct>::new(self.database, 0, first_embedding.len());
262
263        let mut ids: Vec<_> = Vec::with_capacity(embeddings.size_hint().0 + 1);
264
265        {
266            let first_id = self.take_id(&mut wtxn)?;
267            writer.add_item(&mut wtxn, first_id.0, &first_embedding)?;
268            ids.push(first_id);
269        }
270
271        for embedding in embeddings {
272            let id = self.take_id(&mut wtxn)?;
273            writer.add_item(&mut wtxn, id.0, &embedding)?;
274            ids.push(id);
275        }
276
277        self.rebuild(&mut writer, &mut wtxn)?;
278
279        wtxn.commit()?;
280
281        Ok(ids)
282    }
283
284    /// Get the embedding for an embedding id.
285    pub fn get_embedding(&self, embedding_id: EmbeddingId) -> Result<Embedding, VectorDbError> {
286        let rtxn = self.env.read_txn()?;
287        let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
288
289        let embedding = reader
290            .item_vector(&rtxn, embedding_id.0)?
291            .ok_or_else(|| VectorDbError::EmbeddingNotFound(embedding_id))?;
292
293        Ok(Embedding::from(embedding))
294    }
295
296    /// Get the closest N embeddings to the given embedding.
297    pub fn search<'a>(&'a self, embedding: &'a Embedding) -> VectorDBSearchBuilder<'a> {
298        VectorDBSearchBuilder {
299            db: self,
300            embedding,
301            results: None,
302            filter: None,
303        }
304    }
305}
306
307/// A trait for anything that can be used to filter the results of a vector search.
308pub trait IntoVectorDbSearchFilter<M> {
309    /// Convert the filter into a set of candidates.
310    fn into_vector_db_search_filter(self, db: &VectorDB) -> Candidates;
311}
312
313impl IntoVectorDbSearchFilter<()> for Candidates {
314    fn into_vector_db_search_filter(self, _: &VectorDB) -> Candidates {
315        self
316    }
317}
318
/// The marker type selecting the [`IntoVectorDbSearchFilter`] implementation for iterators of
/// [`EmbeddingId`]s.
pub struct IteratorMarker;
321
322impl<I> IntoVectorDbSearchFilter<IteratorMarker> for I
323where
324    I: IntoIterator<Item = EmbeddingId>,
325{
326    fn into_vector_db_search_filter(self, _: &VectorDB) -> Candidates {
327        let mut candidates = Candidates::new();
328        for id in self {
329            candidates.insert(id.0);
330        }
331        candidates
332    }
333}
334
/// The marker type selecting the [`IntoVectorDbSearchFilter`] implementation for closures that
/// decide per [`Embedding`] whether it may be returned.
pub struct ClosureMarker;
337
338impl<I> IntoVectorDbSearchFilter<ClosureMarker> for I
339where
340    I: FnMut(Embedding) -> bool,
341{
342    fn into_vector_db_search_filter(mut self, db: &VectorDB) -> Candidates {
343        let mut candidates = Candidates::new();
344        let rtxn = match db.env.read_txn() {
345            Ok(rtxn) => rtxn,
346            Err(err) => {
347                tracing::error!("Error opening read transaction: {:?}", err);
348                return candidates;
349            }
350        };
351        let reader = match Reader::<DotProduct>::open(&rtxn, 0, db.database) {
352            Ok(reader) => reader,
353            Err(err) => {
354                tracing::error!("Error opening reader: {:?}", err);
355                return candidates;
356            }
357        };
358        for (key, tensor) in reader.iter(&rtxn).ok().into_iter().flatten().flatten() {
359            let embedding = Embedding::from(tensor);
360            if self(embedding) {
361                candidates.insert(key);
362            }
363        }
364        candidates
365    }
366}
367
368/// A builder for searching for embeddings in a vector database.
369pub struct VectorDBSearchBuilder<'a> {
370    db: &'a VectorDB,
371    embedding: &'a Embedding,
372    results: Option<usize>,
373    filter: Option<Candidates>,
374}
375
376impl VectorDBSearchBuilder<'_> {
377    /// Set the number of results to return. Defaults to 10.
378    pub fn with_results(mut self, results: usize) -> Self {
379        self.results = Some(results);
380        self
381    }
382
383    /// Set a filter to apply to the results. Only vectors that pass the filter will be returned.
384    pub fn with_filter<Marker>(
385        mut self,
386        filter: impl IntoVectorDbSearchFilter<Marker> + Send + Sync + 'static,
387    ) -> Self {
388        self.filter = Some(filter.into_vector_db_search_filter(self.db));
389        self
390    }
391
392    /// Run the search and return the results.
393    pub fn run(self) -> Result<Vec<VectorDBSearchResult>, VectorDbError> {
394        let rtxn = self.db.env.read_txn()?;
395        let reader = Reader::<DotProduct>::open(&rtxn, 0, self.db.database)?;
396
397        let vector = self.embedding.vector();
398        let mut query = reader.nns(self.results.unwrap_or(10));
399        if let Some(filter) = self.filter.as_ref() {
400            query.candidates(filter);
401        }
402        let arroy_results = query.by_vector(&rtxn, vector)?;
403
404        Ok(arroy_results
405            .into_iter()
406            .map(|(id, distance)| {
407                let value = EmbeddingId(id);
408                VectorDBSearchResult { distance, value }
409            })
410            .collect::<Vec<_>>())
411    }
412}
413
414/// A resulting point from a search.
415#[derive(Debug, Clone, PartialEq)]
416pub struct VectorDBSearchResult {
417    /// The distance from the searched point.
418    pub distance: f32,
419    /// The value of the point.
420    pub value: EmbeddingId,
421}
422
423/// A unique identifier for an embedding. If you delete an embedding, the id will be recycled.
424#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
425pub struct EmbeddingId(pub u32);
426
427#[tokio::test]
428async fn test_vector_db_get_closest() {
429    let db: VectorDB = VectorDB::new().unwrap();
430    let first_vector = Embedding::from([1.0, 2.0, 3.0]);
431    let second_embedding = Embedding::from([-1.0, 2.0, 3.0]);
432    let id1 = db.add_embedding(first_vector.clone()).unwrap();
433    let id2 = db.add_embedding(second_embedding.clone()).unwrap();
434    assert_eq!(
435        db.search(&first_vector)
436            .with_results(1)
437            .run()
438            .unwrap()
439            .iter()
440            .map(|r| r.value)
441            .collect::<Vec<_>>(),
442        vec![id1]
443    );
444    assert_eq!(
445        db.search(&second_embedding)
446            .with_results(1)
447            .run()
448            .unwrap()
449            .iter()
450            .map(|r| r.value)
451            .collect::<Vec<_>>(),
452        vec![id2]
453    );
454    let third_embedding = Embedding::from([1.0, 0.0, 0.0]);
455    assert_eq!(
456        db.search(&third_embedding)
457            .with_results(1)
458            .run()
459            .unwrap()
460            .iter()
461            .map(|r| r.value)
462            .collect::<Vec<_>>(),
463        vec![id1]
464    );
465    assert_eq!(
466        db.search(&third_embedding)
467            .with_filter(|vector: Embedding| vector.vector()[0] < 0.0)
468            .run()
469            .unwrap()
470            .iter()
471            .map(|r| r.value)
472            .collect::<Vec<_>>(),
473        vec![id2]
474    );
475}