use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::Mutex;
use arroy::distances::Euclidean;
use arroy::{Database as ArroyDatabase, Reader, Writer};
use heed::EnvOpenOptions;
use kalosm_language_model::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
/// A vector database that stores embeddings in an arroy index backed by a heed environment.
///
/// The `S` type parameter tags the database with the vector space of the embeddings it
/// accepts; it is only carried as a `PhantomData` marker and defaults to `UnknownVectorSpace`.
pub struct VectorDB<S = UnknownVectorSpace> {
/// Handle to the arroy database, typed with the `Euclidean` distance.
database: ArroyDatabase<Euclidean>,
/// The heed environment that every read and write transaction is opened on.
env: heed::Env,
/// The next id that has never been handed out; `take_id` returns it and then increments it.
max_id: Mutex<EmbeddingId>,
/// Ids released by `remove_embedding`; `take_id` reuses these before allocating a new id.
recycled_ids: Mutex<Vec<EmbeddingId>>,
/// Cached embedding dimension. `0` means "not known yet" (see `get_dim`).
dim: AtomicUsize,
/// Zero-sized marker tying the database to its vector space `S`.
_phantom: std::marker::PhantomData<S>,
}
impl<S: VectorSpace + Sync> Default for VectorDB<S> {
fn default() -> Self {
Self::new().unwrap()
}
}
impl<S: VectorSpace + Sync> VectorDB<S> {
/// Caches the embedding dimension used by this database.
///
/// # Panics
///
/// Panics if `dim` is `0`, because `0` is reserved as the "unknown" marker read by `get_dim`.
fn set_dim(&self, dim: usize) {
    assert!(dim != 0, "Dimension cannot be 0");
    self.dim.store(dim, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the embedding dimension of this database.
///
/// The cached value is returned when it is non-zero. Otherwise the dimension is read from
/// the stored index through an arroy `Reader` and cached for later calls.
///
/// # Errors
///
/// Returns an error if the read transaction or the reader cannot be opened.
fn get_dim(&self) -> anyhow::Result<usize> {
    let cached = self.dim.load(std::sync::atomic::Ordering::Relaxed);
    if cached != 0 {
        return Ok(cached);
    }

    // Nothing cached yet: ask the on-disk index and remember the answer.
    let rtxn = self.env.read_txn()?;
    let stored = Reader::<Euclidean>::open(&rtxn, 0, self.database)?.dimensions();
    self.set_dim(stored);
    Ok(stored)
}
/// Creates a new database stored in a freshly created temporary directory.
///
/// NOTE(review): the `tempfile` directory guard is a local and is dropped when this
/// function returns; presumably that removes the directory while the environment is
/// still open — confirm this is intended on every supported platform.
///
/// # Errors
///
/// Returns an error if the temporary directory cannot be created or if
/// [`VectorDB::new_at`] fails.
#[tracing::instrument]
pub fn new() -> anyhow::Result<Self> {
    let scratch = tempfile::tempdir()?;
    let location = scratch.path();
    Self::new_at(location)
}
/// Creates (or opens) a database stored in the directory at `path`.
///
/// The directory is created if it does not exist, the heed environment is opened with a
/// 2 GiB map size, and the unnamed database is created inside it.
///
/// NOTE(review): the id counter always starts at `0` and the recycled-id list starts empty,
/// so opening a directory that already contains embeddings would hand out ids that are
/// already in use — confirm whether reopening existing data is a supported use case.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the environment cannot be opened,
/// or the database cannot be created and committed.
pub fn new_at(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
    // 2 * 1024 MiB, i.e. 2 GiB of address space reserved for the memory map.
    const MAP_SIZE_BYTES: usize = 2 * 1024 * 1024 * 1024;

    std::fs::create_dir_all(&path)?;

    // SAFETY (assumed — TODO confirm against the heed documentation): opening an
    // environment is an `unsafe` operation in heed; this code relies on the directory
    // being used only through this environment while it is open.
    let env = unsafe {
        EnvOpenOptions::new()
            .map_size(MAP_SIZE_BYTES)
            .open(path)?
    };

    // The database has to be created inside a write transaction that is then committed.
    let mut wtxn = env.write_txn()?;
    let database: ArroyDatabase<Euclidean> = env.create_database(&mut wtxn, None)?;
    wtxn.commit()?;

    let max_id = Mutex::new(EmbeddingId(0));
    let recycled_ids = Mutex::new(Vec::new());
    let dim = AtomicUsize::new(0);

    Ok(Self {
        database,
        env,
        max_id,
        recycled_ids,
        dim,
        _phantom: std::marker::PhantomData,
    })
}
/// Hands out an id for a new embedding.
///
/// A previously recycled id is reused when one is available; otherwise the current value
/// of the `max_id` counter is returned and the counter is incremented.
///
/// # Panics
///
/// Panics if either mutex is poisoned.
fn take_id(&self) -> EmbeddingId {
    // The free-list guard stays alive for the whole function, so the counter below is
    // only touched while the free list is locked.
    let mut free_list = self.recycled_ids.lock().unwrap();
    if let Some(reused) = free_list.pop() {
        return reused;
    }

    let mut next = self.max_id.lock().unwrap();
    let fresh = *next;
    next.0 += 1;
    fresh
}
/// Puts `id` on the free list so that a later `take_id` call can hand it out again.
///
/// # Panics
///
/// Panics if the free-list mutex is poisoned.
fn recycle_id(&self, id: EmbeddingId) {
    let mut free_list = self.recycled_ids.lock().unwrap();
    free_list.push(id);
}
pub fn raw(&self) -> (&ArroyDatabase<Euclidean>, &heed::Env) {
(&self.database, &self.env)
}
/// Removes the embedding stored under `embedding_id` and rebuilds the index.
///
/// The id is made available for reuse only after the transaction has committed, and only
/// if an item was actually deleted. This keeps an id off the free list when the removal
/// failed (the item is still stored) and prevents the same id from being queued twice when
/// it is removed more than once.
///
/// # Errors
///
/// Returns an error if the dimension cannot be determined, or if opening the transaction,
/// deleting the item, rebuilding the index, or committing fails. On error the id is not
/// recycled.
pub fn remove_embedding(&self, embedding_id: EmbeddingId) -> anyhow::Result<()> {
    let dims = self.get_dim()?;
    let mut wtxn = self.env.write_txn()?;
    let writer = Writer::<Euclidean>::new(self.database, 0, dims)?;

    // `del_item` reports whether the item was present in the database.
    let existed = writer.del_item(&mut wtxn, embedding_id.0)?;

    let mut rng = StdRng::from_entropy();
    writer.build(&mut wtxn, &mut rng, None)?;
    wtxn.commit()?;

    // Only now is the id really free: the deletion is durable and the id was in use.
    if existed {
        self.recycle_id(embedding_id);
    }
    Ok(())
}
/// Adds a single embedding to the database, rebuilds the index, and returns the id the
/// embedding was stored under.
///
/// # Errors
///
/// Returns an error if the embedding is empty (dimension `0`), if it cannot be converted
/// to a flat vector, or if opening the transaction, inserting the item, rebuilding the
/// index, or committing fails. When the write fails after an id has been allocated, the
/// id is returned to the free list so it is not lost.
pub fn add_embedding(&self, embedding: Embedding<S>) -> anyhow::Result<EmbeddingId> {
    let vector = embedding.vector().to_vec1()?;
    // Report an empty embedding as an error instead of panicking in `set_dim`.
    anyhow::ensure!(
        !vector.is_empty(),
        "cannot add an embedding with dimension 0"
    );
    let dims = vector.len();

    let mut wtxn = self.env.write_txn()?;
    let writer = Writer::<Euclidean>::new(self.database, 0, dims)?;
    let id = self.take_id();

    // Run the fallible part as one unit so the id can be given back if any step fails.
    let written = (|| -> anyhow::Result<()> {
        writer.add_item(&mut wtxn, id.0, &vector)?;
        let mut rng = StdRng::from_entropy();
        writer.build(&mut wtxn, &mut rng, None)?;
        wtxn.commit()?;
        Ok(())
    })();

    match written {
        Ok(()) => {
            // Cache the dimension only once the data is durably stored.
            self.set_dim(dims);
            Ok(id)
        }
        Err(err) => {
            self.recycle_id(id);
            Err(err)
        }
    }
}
/// Adds a batch of embeddings in a single transaction, rebuilds the index once, and returns
/// the ids in the same order as the input.
///
/// An empty input returns an empty vector without touching the database. The dimension of
/// the first embedding is used to configure the writer for the whole batch.
///
/// # Errors
///
/// Returns an error if the first embedding is empty (dimension `0`), if any embedding cannot
/// be converted to a flat vector, or if opening the transaction, inserting an item,
/// rebuilding the index, or committing fails. On failure nothing is committed and every id
/// allocated for the batch is returned to the free list.
pub fn add_embeddings(
    &self,
    embedding: impl IntoIterator<Item = Embedding<S>>,
) -> anyhow::Result<Vec<EmbeddingId>> {
    let mut vectors = embedding.into_iter().map(|e| e.vector().to_vec1());
    let first = match vectors.next() {
        Some(vector) => vector?,
        None => return Ok(Vec::new()),
    };
    // Report an empty embedding as an error instead of panicking in `set_dim`.
    anyhow::ensure!(
        !first.is_empty(),
        "cannot add an embedding with dimension 0"
    );
    let dims = first.len();

    let mut wtxn = self.env.write_txn()?;
    let writer = Writer::<Euclidean>::new(self.database, 0, dims)?;
    let mut ids: Vec<EmbeddingId> = Vec::with_capacity(vectors.size_hint().0 + 1);

    // Run the fallible part as one unit so every allocated id can be given back on failure.
    // Each id is recorded in `ids` before the insert that may fail.
    let written = (|| -> anyhow::Result<()> {
        let first_id = self.take_id();
        ids.push(first_id);
        writer.add_item(&mut wtxn, first_id.0, &first)?;

        for vector in vectors {
            let vector = vector?;
            let id = self.take_id();
            ids.push(id);
            writer.add_item(&mut wtxn, id.0, &vector)?;
        }

        let mut rng = StdRng::from_entropy();
        writer.build(&mut wtxn, &mut rng, None)?;
        wtxn.commit()?;
        Ok(())
    })();

    if let Err(err) = written {
        // The transaction was not committed, so none of these ids is in use.
        for id in ids {
            self.recycle_id(id);
        }
        return Err(err);
    }

    // Cache the dimension only once the data is durably stored.
    self.set_dim(dims);
    Ok(ids)
}
/// Searches the index for the `n` stored embeddings nearest to `embedding`.
///
/// Each result pairs the id of a stored embedding with the distance arroy reported for it;
/// results are returned in the order arroy produced them.
///
/// # Errors
///
/// Returns an error if the read transaction or reader cannot be opened, if the query
/// embedding cannot be converted to a flat vector, or if the search itself fails.
pub fn get_closest(
    &self,
    embedding: Embedding<S>,
    n: usize,
) -> anyhow::Result<Vec<VectorDBSearchResult>> {
    let rtxn = self.env.read_txn()?;
    let reader = Reader::<Euclidean>::open(&rtxn, 0, self.database)?;
    let query = embedding.vector().to_vec1()?;
    let neighbours = reader.nns_by_vector(&rtxn, &query, n, None, None)?;

    let mut hits = Vec::with_capacity(neighbours.len());
    for (item, distance) in neighbours {
        hits.push(VectorDBSearchResult {
            distance,
            value: EmbeddingId(item),
        });
    }
    Ok(hits)
}
}
/// A single match returned by `VectorDB::get_closest`.
#[derive(Debug, Clone)]
pub struct VectorDBSearchResult {
/// The distance between the query and this match, as reported by the arroy search.
pub distance: f32,
/// The id of the stored embedding that matched.
pub value: EmbeddingId,
}
/// Identifier of an embedding stored in a `VectorDB`.
///
/// Wraps the `u32` item id that is passed to the arroy reader and writer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(pub u32);