Skip to main content

spark_bert/
api.rs

1use anyhow::Result;
2use candle_core::Device;
3
4use crate::{
5    args::Args,
6    embs::{convert_to_flatten_vec, Bert},
7    inverted_index::InvertedIndex,
8    score::calculate_max_sim,
9    vector_vocabulary::VectorVocabulary,
10};
11
12pub struct SparkBert {
13    vector_vocabulary: VectorVocabulary,
14    inverted_index: InvertedIndex,
15    bert: Bert,
16    config: Config,
17}
18pub struct Config {
19    pub use_ram_index: bool,
20    pub device: Device,
21    pub index_n_neighbors: usize,
22}
23
24impl SparkBert {
25    pub fn new(config: Config) -> Result<Self> {
26        let vector_vocabulary = VectorVocabulary::build()?;
27        println!(
28            "Vector vocabulary size: {}",
29            vector_vocabulary.get_num_tokens()
30        );
31        let inverted_index = InvertedIndex::open(config.use_ram_index)?;
32        let args = Args {
33            cpu: config.device.is_cpu(),
34            tracing: false,
35            model_id: Option::None,
36            revision: Option::None,
37            use_pth: false,
38            normalize_embeddings: true,
39            approximate_gelu: false,
40        };
41        let bert = Bert::build(args)?;
42        Ok(Self {
43            vector_vocabulary,
44            inverted_index,
45            bert,
46            config,
47        })
48    }
49
50    pub fn search(
51        &mut self,
52        query: &str,
53        search_n_neighbors: usize,
54        top_k: usize,
55    ) -> Result<Vec<(u64, f64)>> {
56        let query_embs = self.bert.calc_embs(vec![query], false)?;
57        let flat_query_embs = convert_to_flatten_vec(&query_embs)?;
58        let (tokens, _) =
59            self.vector_vocabulary
60                .find_tokens(&flat_query_embs, search_n_neighbors, false)?;
61        let doc_id_score_pairs = self.inverted_index.search(None, tokens.as_slice(), top_k)?;
62        Ok(doc_id_score_pairs)
63    }
64
65    pub fn index<I>(&mut self, docs: I, filter_stop_words: bool) -> Result<()>
66    where
67        I: IntoIterator<Item = (u64, String)>,
68    {
69        let d = self.vector_vocabulary.get_embedding_dims() as usize;
70        for (doc_id, text) in docs {
71            let doc_embs = &self.bert.calc_embs(vec![text.as_str()], false)?;
72            let flat_doc_embs = convert_to_flatten_vec(doc_embs)?;
73            let (tokens, token_embs) = self.vector_vocabulary.find_tokens(
74                &flat_doc_embs,
75                self.config.index_n_neighbors,
76                true,
77            )?;
78            let scores =
79                calculate_max_sim(flat_doc_embs, token_embs.unwrap(), &self.config.device, d)?;
80            self.inverted_index.index(
81                doc_id,
82                tokens.iter().map(|&it| it.to_owned()).collect(),
83                scores,
84            );
85        }
86        self.inverted_index.finalize(filter_stop_words)?;
87        Ok(())
88    }
89
90    pub fn get_num_docs(&self) -> u64 {
91        self.inverted_index.get_num_docs()
92    }
93}