contextdb-vector 1.0.0

Vector executor for contextdb — cosine similarity, HNSW indexing, pre-filtered search
Documentation
use crate::{HnswIndex, store::VectorStore};
use contextdb_core::*;
use contextdb_tx::{TxManager, WriteSetApplicator};
use parking_lot::RwLock;
use roaring::RoaringTreemap;
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};

const HNSW_THRESHOLD: usize = 1000;

pub struct MemVectorExecutor<S: WriteSetApplicator> {
    store: Arc<VectorStore>,
    tx_mgr: Arc<TxManager<S>>,
    accountant: Arc<MemoryAccountant>,
}

impl<S: WriteSetApplicator> MemVectorExecutor<S> {
    pub fn new(
        store: Arc<VectorStore>,
        tx_mgr: Arc<TxManager<S>>,
        hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
    ) -> Self {
        Self::new_with_accountant(store, tx_mgr, hnsw, Arc::new(MemoryAccountant::no_limit()))
    }

    pub fn new_with_accountant(
        store: Arc<VectorStore>,
        tx_mgr: Arc<TxManager<S>>,
        _hnsw: Arc<OnceLock<RwLock<Option<HnswIndex>>>>,
        accountant: Arc<MemoryAccountant>,
    ) -> Self {
        Self {
            store,
            tx_mgr,
            accountant,
        }
    }

    fn brute_force_search(
        &self,
        index: &VectorIndexRef,
        query: &[f32],
        k: usize,
        candidates: Option<&RoaringTreemap>,
        snapshot: SnapshotId,
    ) -> Result<Vec<(RowId, f32)>> {
        let state = self.store.state(index)?;
        if query.len() != state.dimension() {
            return Err(Error::VectorIndexDimensionMismatch {
                index: index.clone(),
                expected: state.dimension(),
                actual: query.len(),
            });
        }

        let mut scored: Vec<(RowId, f32)> = state.with_entries(|entries| {
            let mut scored = Vec::new();
            for entry in entries {
                if !entry.visible_at(snapshot) {
                    continue;
                }

                if let Some(cands) = candidates
                    && !cands.contains(entry.row_id.0)
                {
                    continue;
                }

                let sim = entry.vector.cosine_similarity(query);
                scored.push((entry.row_id, sim));
            }
            scored
        });

        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        Ok(scored)
    }

    fn build_hnsw_from_store(&self, index: &VectorIndexRef) -> Option<HnswIndex> {
        let _build_guard = self.store.build_lock();
        let state = self.store.try_state(index)?;
        let dim = state.dimension();
        let entry_count = state.entry_count();
        let estimated_bytes = estimate_hnsw_bytes(entry_count, dim, state.quantization());
        if self
            .accountant
            .try_allocate_for(
                estimated_bytes,
                "vector_index",
                &format!("build_hnsw@{}.{}", index.table, index.column),
                "Reduce vector volume or raise MEMORY_LIMIT so the HNSW index can be built.",
            )
            .is_err()
        {
            return None;
        }

        let built = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            state.with_entries(|entries| HnswIndex::new(entries, dim, state.quantization()))
        }))
        .ok();
        if built.is_none() {
            self.accountant.release(estimated_bytes);
        } else {
            state.set_hnsw_bytes(estimated_bytes);
        }
        built
    }
}

impl<S: WriteSetApplicator> VectorExecutor for MemVectorExecutor<S> {
    fn search(
        &self,
        index: VectorIndexRef,
        query: &[f32],
        k: usize,
        candidates: Option<&RoaringTreemap>,
        snapshot: SnapshotId,
    ) -> Result<Vec<(RowId, f32)>> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let Some(state) = self.store.try_state(&index) else {
            return Err(Error::UnknownVectorIndex { index });
        };
        if query.len() != state.dimension() {
            return Err(Error::VectorIndexDimensionMismatch {
                index,
                expected: state.dimension(),
                actual: query.len(),
            });
        }
        if state.entry_count() == 0 {
            return Ok(Vec::new());
        }

        let use_hnsw = state.entry_count() >= HNSW_THRESHOLD;
        if use_hnsw {
            let lock = state
                .hnsw()
                .get_or_init(|| RwLock::new(self.build_hnsw_from_store(&index)));

            {
                let mut guard = lock.write();
                if guard.is_none() {
                    *guard = self.build_hnsw_from_store(&index);
                }
            }

            let guard = lock.read();
            if let Some(hnsw) = guard.as_ref() {
                let raw_candidates = hnsw.search(&index, query, k)?;
                let raw_candidate_count = raw_candidates.len();

                if candidates.is_some() && raw_candidates.len() < hnsw.len() {
                    return self.brute_force_search(&index, query, k, candidates, snapshot);
                }

                let supplement_missing = raw_candidate_count.saturating_add(64) >= hnsw.len();
                let raw_row_ids = if supplement_missing {
                    raw_candidates
                        .iter()
                        .map(|(row_id, _)| *row_id)
                        .collect::<HashSet<_>>()
                } else {
                    HashSet::new()
                };
                let mut visible = state.with_entries(|entries| {
                    let mut visible = raw_candidates
                        .into_iter()
                        .filter_map(|(rid, _)| {
                            entries
                                .iter()
                                .find(|entry| entry.row_id == rid && entry.visible_at(snapshot))
                                .and_then(|entry| {
                                    if let Some(cands) = candidates
                                        && !cands.contains(entry.row_id.0)
                                    {
                                        return None;
                                    }
                                    Some((entry.row_id, entry.vector.cosine_similarity(query)))
                                })
                        })
                        .collect::<Vec<_>>();
                    if supplement_missing {
                        for entry in entries {
                            if raw_row_ids.contains(&entry.row_id) || !entry.visible_at(snapshot) {
                                continue;
                            }
                            visible.push((entry.row_id, entry.vector.cosine_similarity(query)));
                        }
                    }
                    visible
                });

                visible.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                if visible.len() < k && raw_candidate_count < hnsw.len() {
                    return self.brute_force_search(&index, query, k, candidates, snapshot);
                }
                visible.truncate(k);
                return Ok(visible);
            }
        }

        self.brute_force_search(&index, query, k, candidates, snapshot)
    }

    fn insert_vector(
        &self,
        tx: TxId,
        index: VectorIndexRef,
        row_id: RowId,
        vector: Vec<f32>,
    ) -> Result<()> {
        self.store.validate_vector(&index, vector.len())?;
        let entry = VectorEntry {
            index,
            row_id,
            vector,
            created_tx: tx,
            deleted_tx: None,
            lsn: contextdb_core::Lsn(0),
        };

        self.tx_mgr.with_write_set(tx, |ws| {
            ws.vector_inserts.push(entry);
        })?;

        Ok(())
    }

    fn delete_vector(&self, tx: TxId, index: VectorIndexRef, row_id: RowId) -> Result<()> {
        self.store.state(&index)?;
        self.tx_mgr.with_write_set(tx, |ws| {
            ws.vector_deletes.push((index, row_id, tx));
        })?;

        Ok(())
    }
}

fn estimate_hnsw_bytes(
    entry_count: usize,
    dimension: usize,
    quantization: VectorQuantization,
) -> usize {
    let entry_bytes = match quantization {
        VectorQuantization::F32 => quantization.storage_bytes(dimension),
        VectorQuantization::SQ8 => dimension.saturating_add(12),
        VectorQuantization::SQ4 => dimension.div_ceil(2).saturating_add(12),
    };
    entry_count.saturating_mul(entry_bytes).saturating_mul(3)
}