1use anyhow::Result;
2use candle_core::Device;
3
4use crate::{
5 args::Args,
6 embs::{convert_to_flatten_vec, Bert},
7 inverted_index::InvertedIndex,
8 score::calculate_max_sim,
9 vector_vocabulary::VectorVocabulary,
10};
11
12pub struct SparkBert {
13 vector_vocabulary: VectorVocabulary,
14 inverted_index: InvertedIndex,
15 bert: Bert,
16 config: Config,
17}
18pub struct Config {
19 pub use_ram_index: bool,
20 pub device: Device,
21 pub index_n_neighbors: usize,
22}
23
24impl SparkBert {
25 pub fn new(config: Config) -> Result<Self> {
26 let vector_vocabulary = VectorVocabulary::build()?;
27 println!(
28 "Vector vocabulary size: {}",
29 vector_vocabulary.get_num_tokens()
30 );
31 let inverted_index = InvertedIndex::open(config.use_ram_index)?;
32 let args = Args {
33 cpu: config.device.is_cpu(),
34 tracing: false,
35 model_id: Option::None,
36 revision: Option::None,
37 use_pth: false,
38 normalize_embeddings: true,
39 approximate_gelu: false,
40 };
41 let bert = Bert::build(args)?;
42 Ok(Self {
43 vector_vocabulary,
44 inverted_index,
45 bert,
46 config,
47 })
48 }
49
50 pub fn search(
51 &mut self,
52 query: &str,
53 search_n_neighbors: usize,
54 top_k: usize,
55 ) -> Result<Vec<(u64, f64)>> {
56 let query_embs = self.bert.calc_embs(vec![query], false)?;
57 let flat_query_embs = convert_to_flatten_vec(&query_embs)?;
58 let (tokens, _) =
59 self.vector_vocabulary
60 .find_tokens(&flat_query_embs, search_n_neighbors, false)?;
61 let doc_id_score_pairs = self.inverted_index.search(None, tokens.as_slice(), top_k)?;
62 Ok(doc_id_score_pairs)
63 }
64
65 pub fn index<I>(&mut self, docs: I, filter_stop_words: bool) -> Result<()>
66 where
67 I: IntoIterator<Item = (u64, String)>,
68 {
69 let d = self.vector_vocabulary.get_embedding_dims() as usize;
70 for (doc_id, text) in docs {
71 let doc_embs = &self.bert.calc_embs(vec![text.as_str()], false)?;
72 let flat_doc_embs = convert_to_flatten_vec(doc_embs)?;
73 let (tokens, token_embs) = self.vector_vocabulary.find_tokens(
74 &flat_doc_embs,
75 self.config.index_n_neighbors,
76 true,
77 )?;
78 let scores =
79 calculate_max_sim(flat_doc_embs, token_embs.unwrap(), &self.config.device, d)?;
80 self.inverted_index.index(
81 doc_id,
82 tokens.iter().map(|&it| it.to_owned()).collect(),
83 scores,
84 );
85 }
86 self.inverted_index.finalize(filter_stop_words)?;
87 Ok(())
88 }
89
90 pub fn get_num_docs(&self) -> u64 {
91 self.inverted_index.get_num_docs()
92 }
93}