spark-bert 0.1.1

Hybrid vector search using an inverted index and BERT embeddings
Documentation
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;

use anyhow::{anyhow, Result};
use faiss::index::{IndexImpl, SearchResult};
use faiss::{read_index, Idx, Index};
use hf_hub::api::sync::Api;
use hf_hub::Repo;
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;

/// A token vocabulary that can be searched by embedding.
///
/// Holds a FAISS index together with the table that maps the index's ids
/// (as decimal strings) back to token strings.
pub struct VectorVocabulary {
    /// FAISS id rendered as a decimal string → token text.
    faiss_idx_to_token: HashMap<String, String>,
    /// FAISS index that nearest-token searches run against.
    vector_index: IndexImpl,
}

impl VectorVocabulary {
    /// Hugging Face Hub model repository that the vocabulary files are fetched from.
    const REPO_ID: &'static str =
        "viacheslav-dobrynin/spark-bert-msmarco-all-MiniLM-L6-v2-vector-vocab";

    /// Loads the vocabulary from its Hugging Face Hub model repository (see `REPO_ID`).
    ///
    /// Two files are fetched through the synchronous `hf_hub` API:
    /// - `vector_vocab.hnsw.faiss`, loaded with `faiss::read_index` as the searchable index;
    /// - `faiss_idx_to_token.json`, a JSON object mapping stringified FAISS ids to tokens.
    ///
    /// # Errors
    /// Fails if the Hub API cannot be created, a file cannot be fetched, a fetched path
    /// is not valid UTF-8, the FAISS index cannot be read, or the JSON file cannot be
    /// opened or parsed.
    pub fn build() -> Result<Self> {
        let api = Api::new()?.repo(Repo::model(Self::REPO_ID.to_owned()));
        // The loaders are given string paths, so the path returned by the Hub API is
        // converted to a `String`. A non-UTF-8 path is reported rather than lossily mangled.
        let fetch = |filename: &str| -> Result<String> {
            api.get(filename)?
                .into_os_string()
                .into_string()
                .map_err(|path| anyhow!("cache path is not valid UTF-8: {:?}", path))
        };
        let faiss_idx_to_token_path = fetch("faiss_idx_to_token.json")?;
        let vector_vocab_path = fetch("vector_vocab.hnsw.faiss")?;
        let vector_index = read_index(&vector_vocab_path)?;
        let faiss_idx_to_token = Self::load_faiss_idx_to_token(&faiss_idx_to_token_path)?;
        Ok(Self {
            vector_index,
            faiss_idx_to_token,
        })
    }

    /// Reads the JSON object at `json_path` into a "stringified FAISS id → token" map.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or is not a JSON object of string values.
    fn load_faiss_idx_to_token(json_path: &str) -> Result<HashMap<String, String>> {
        let reader = BufReader::new(File::open(json_path)?);
        let faiss_idx_to_token: HashMap<String, String> = serde_json::from_reader(reader)?;
        Ok(faiss_idx_to_token)
    }

    /// Number of vectors stored in the FAISS index (its `ntotal`).
    pub fn get_num_tokens(&self) -> u64 {
        self.vector_index.ntotal()
    }

    /// Dimensionality of the vectors stored in the FAISS index (its `d`).
    pub fn get_embedding_dims(&self) -> u32 {
        self.vector_index.d()
    }

    /// Finds the vocabulary tokens nearest to the given query embeddings.
    ///
    /// `query_embs` is passed straight to the FAISS search. It is presumably a flat,
    /// row-major buffer of `n_queries × get_embedding_dims()` floats (TODO: confirm with callers).
    /// The `n_neighbors` hits of all queries are merged, and ids FAISS left unassigned are
    /// dropped. The rest are deduplicated and sorted, so the tokens come back in ascending
    /// FAISS-id order, not by similarity. Search distances are discarded.
    ///
    /// When `with_embs` is true, the second element holds the reconstructed vectors of
    /// the returned tokens. They form one flat buffer of `tokens.len() × get_embedding_dims()`
    /// floats, in which row `i` belongs to `tokens[i]`.
    ///
    /// # Errors
    /// Fails if the FAISS search or reconstruction fails, or if a returned id has no
    /// entry in the id-to-token mapping.
    pub fn find_tokens(
        &mut self,
        query_embs: &[f32],
        n_neighbors: usize,
        with_embs: bool,
    ) -> Result<(Vec<&str>, Option<Vec<f32>>)> {
        // Only the set of hit ids matters here; distances are not used.
        let SearchResult { labels, .. } = self.vector_index.search(query_embs, n_neighbors)?;
        let labels = unique_labels(&labels);
        // A mapping file that is out of sync with the index is external data, so a
        // missing entry is reported as an error instead of panicking.
        let tokens = labels
            .iter()
            .map(|label| {
                label
                    .get()
                    .and_then(|id| self.faiss_idx_to_token.get(&id.to_string()))
                    .map(String::as_str)
                    .ok_or_else(|| anyhow!("no token mapped to faiss index label {:?}", label))
            })
            .collect::<Result<Vec<&str>>>()?;
        let token_embs = if with_embs {
            Some(reconstruct_batch(&self.vector_index, &labels)?)
        } else {
            None
        };
        Ok((tokens, token_embs))
    }
}

pub fn reconstruct_batch<T>(index: &T, labels: &[faiss::Idx]) -> anyhow::Result<Vec<f32>>
where
    T: Index + Sync,
{
    let d = index.d() as usize;
    let batch = labels.len();
    let mut flat_embs = vec![0f32; batch * d];
    debug_assert_eq!(flat_embs.len(), labels.len() * d);
    flat_embs
        .par_chunks_mut(d)
        .enumerate()
        .try_for_each(|(i, chunk)| {
            let idx = labels[i];
            index.reconstruct(idx, chunk).map_err(anyhow::Error::from)
        })?;
    anyhow::Ok(flat_embs)
}

/// Returns the distinct assigned ids in `labels`, sorted in ascending order.
///
/// Entries for which `Idx::get` yields `None` (ids FAISS left unassigned) are skipped.
pub fn unique_labels(labels: &[Idx]) -> Vec<Idx> {
    let mut ids = Vec::with_capacity(labels.len());
    for label in labels {
        if let Some(id) = label.get() {
            ids.push(id);
        }
    }
    // Duplicates become adjacent only after sorting, so sort before dedup.
    ids.sort_unstable();
    ids.dedup();
    ids.into_iter().map(Idx::new).collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use faiss::error::Result as FaissResult;
    use faiss::{Idx, MetricType};

    #[test]
    fn should_reconstruct_batch_of_embs() {
        let index = MockIndex {
            vecs: vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        };

        let flat = reconstruct_batch(&index, &[Idx::new(0), Idx::new(1)]).unwrap();

        let expected = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        assert_eq!(flat, expected);
    }

    #[test]
    fn should_return_unique_labels() {
        let labels: Vec<Idx> = [2u64, 1, 2, 3, 1].iter().copied().map(Idx::new).collect();

        let deduped = unique_labels(&labels);

        let expected: Vec<Idx> = [1u64, 2, 3].iter().copied().map(Idx::new).collect();
        assert_eq!(deduped, expected);
    }

    /// In-memory stand-in for a FAISS index.
    ///
    /// Only `d` and `reconstruct` work. Every other trait method is a `todo!()` stub,
    /// because the tests never call them.
    struct MockIndex {
        vecs: Vec<Vec<f32>>,
    }

    impl faiss::Index for MockIndex {
        fn d(&self) -> u32 {
            // Dimensionality is taken from the first stored vector.
            self.vecs[0].len() as u32
        }

        fn reconstruct(&self, idx: Idx, dest: &mut [f32]) -> FaissResult<()> {
            let stored = &self.vecs[idx.get().unwrap() as usize];
            dest.copy_from_slice(stored);
            Ok(())
        }

        fn is_trained(&self) -> bool {
            todo!()
        }

        fn ntotal(&self) -> u64 {
            todo!()
        }

        fn metric_type(&self) -> MetricType {
            todo!()
        }

        fn add(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn add_with_ids(&mut self, _x: &[f32], _xids: &[Idx]) -> FaissResult<()> {
            todo!()
        }

        fn train(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn assign(&mut self, _q: &[f32], _k: usize) -> FaissResult<faiss::index::AssignSearchResult> {
            todo!()
        }

        fn search(&mut self, _q: &[f32], _k: usize) -> FaissResult<faiss::index::SearchResult> {
            todo!()
        }

        fn range_search(
            &mut self,
            _q: &[f32],
            _radius: f32,
        ) -> FaissResult<faiss::index::RangeSearchResult> {
            todo!()
        }

        fn reconstruct_n(
            &self,
            _first_key: Idx,
            _count: usize,
            _output: &mut [f32],
        ) -> FaissResult<()> {
            todo!()
        }

        fn reset(&mut self) -> FaissResult<()> {
            todo!()
        }

        fn remove_ids(&mut self, _sel: &faiss::selector::IdSelector) -> FaissResult<usize> {
            todo!()
        }

        fn verbose(&self) -> bool {
            todo!()
        }

        fn set_verbose(&mut self, _value: bool) {
            todo!()
        }
    }
}