//! spark-bert 0.1.1
//!
//! Hybrid vector search using an inverted index and BERT embeddings.
//!
//! Documentation
use anyhow::Result;
use candle_core::Device;

use crate::{
    args::Args,
    embs::{convert_to_flatten_vec, Bert},
    inverted_index::InvertedIndex,
    score::calculate_max_sim,
    vector_vocabulary::VectorVocabulary,
};

/// Search engine facade that ties together a vector vocabulary, an inverted
/// index and a BERT embedder.
///
/// Text is embedded with `bert`, the embeddings are mapped to vocabulary
/// tokens via `vector_vocabulary`, and those tokens are stored in / looked up
/// from `inverted_index`.
pub struct SparkBert {
    /// Maps flattened embeddings to tokens (via `find_tokens`).
    vector_vocabulary: VectorVocabulary,
    /// Stores per-document tokens with scores and answers token queries.
    inverted_index: InvertedIndex,
    /// Embedding model used for both queries and documents.
    bert: Bert,
    /// Settings supplied at construction time.
    config: Config,
}
/// Construction-time settings for [`SparkBert`].
pub struct Config {
    /// Forwarded to `InvertedIndex::open`; presumably selects an in-memory
    /// index over an on-disk one — confirm against `InvertedIndex`.
    pub use_ram_index: bool,
    /// Device handed to the max-similarity scoring step; its `is_cpu()` value
    /// also sets the `cpu` flag of the BERT model arguments.
    pub device: Device,
    /// Neighbor count passed to the vocabulary lookup while indexing documents.
    pub index_n_neighbors: usize,
}

impl SparkBert {
    pub fn new(config: Config) -> Result<Self> {
        let vector_vocabulary = VectorVocabulary::build()?;
        println!(
            "Vector vocabulary size: {}",
            vector_vocabulary.get_num_tokens()
        );
        let inverted_index = InvertedIndex::open(config.use_ram_index)?;
        let args = Args {
            cpu: config.device.is_cpu(),
            tracing: false,
            model_id: Option::None,
            revision: Option::None,
            use_pth: false,
            normalize_embeddings: true,
            approximate_gelu: false,
        };
        let bert = Bert::build(args)?;
        Ok(Self {
            vector_vocabulary,
            inverted_index,
            bert,
            config,
        })
    }

    pub fn search(
        &mut self,
        query: &str,
        search_n_neighbors: usize,
        top_k: usize,
    ) -> Result<Vec<(u64, f64)>> {
        let query_embs = self.bert.calc_embs(vec![query], false)?;
        let flat_query_embs = convert_to_flatten_vec(&query_embs)?;
        let (tokens, _) =
            self.vector_vocabulary
                .find_tokens(&flat_query_embs, search_n_neighbors, false)?;
        let doc_id_score_pairs = self.inverted_index.search(None, tokens.as_slice(), top_k)?;
        Ok(doc_id_score_pairs)
    }

    pub fn index<I>(&mut self, docs: I, filter_stop_words: bool) -> Result<()>
    where
        I: IntoIterator<Item = (u64, String)>,
    {
        let d = self.vector_vocabulary.get_embedding_dims() as usize;
        for (doc_id, text) in docs {
            let doc_embs = &self.bert.calc_embs(vec![text.as_str()], false)?;
            let flat_doc_embs = convert_to_flatten_vec(doc_embs)?;
            let (tokens, token_embs) = self.vector_vocabulary.find_tokens(
                &flat_doc_embs,
                self.config.index_n_neighbors,
                true,
            )?;
            let scores =
                calculate_max_sim(flat_doc_embs, token_embs.unwrap(), &self.config.device, d)?;
            self.inverted_index.index(
                doc_id,
                tokens.iter().map(|&it| it.to_owned()).collect(),
                scores,
            );
        }
        self.inverted_index.finalize(filter_stop_words)?;
        Ok(())
    }

    pub fn get_num_docs(&self) -> u64 {
        self.inverted_index.get_num_docs()
    }
}