// spark_bert/inverted_index.rs
use std::{
    collections::{HashMap, HashSet},
    fs,
    path::PathBuf,
};

use anyhow::{bail, Context, Result};
use float8::F8E4M3;
use tantivy::{
    directory::MmapDirectory,
    query::{BooleanQuery, Query},
    schema::{
        Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, Value, FAST, STORED,
    },
    Index, IndexReader, IndexWriter, ReloadPolicy, Searcher, TantivyDocument, Term,
};

use crate::{directory::ram_directory_from_mmap_dir, tf_term_query::TfTermQuery};

/// Document-frequency cutoff for stop words: when stop-word filtering is
/// enabled, a buffered token that occurs in at least this fraction of the
/// buffered documents is dropped from every document.
const MAX_DF_RATIO: f32 = 0.15;

22pub struct InvertedIndex {
23 index: Index,
24 writer: IndexWriter,
25 pub reader: IndexReader,
26 token_cluster_id: Field,
27 doc_id: Field,
28 pending: HashMap<u64, Vec<(String, f32)>>,
29}
30
31impl InvertedIndex {
32 pub fn open(use_ram_index: bool) -> Result<Self> {
33 let directory_path = Self::default_directory_path();
34 let (schema, token_cluster_id, doc_id) = Self::build_schema()?;
35 let index = if use_ram_index {
36 if directory_path.exists() {
37 let ram_directory = ram_directory_from_mmap_dir(&directory_path)?;
38 Index::open(ram_directory)?
39 } else {
40 Index::create_in_ram(schema)
41 }
42 } else {
43 fs::create_dir_all(&directory_path)?;
44 let directory = MmapDirectory::open(&directory_path)?;
45 Index::open_or_create(directory, schema)?
46 };
47 let memory_budget_in_bytes = 500_000_000; let writer = index.writer(memory_budget_in_bytes)?;
49 let reader = Self::build_reader(&index)?;
50 Ok(Self {
51 index,
52 writer,
53 reader,
54 token_cluster_id,
55 doc_id,
56 pending: HashMap::new(),
57 })
58 }
59
60 fn default_directory_path() -> PathBuf {
61 std::env::var_os("SPARKBERT_INVERTED_INDEX_DIR")
62 .map(PathBuf::from)
63 .context("Please set SPARKBERT_INVERTED_INDEX_DIR env variable")
64 .unwrap()
65 }
66
67 fn build_schema() -> Result<(Schema, Field, Field)> {
68 let mut schema_builder = Schema::builder();
69 let tok_opts = TextOptions::default().set_indexing_options(
70 TextFieldIndexing::default()
71 .set_tokenizer("raw")
72 .set_index_option(IndexRecordOption::WithFreqs),
73 );
74 let token_cluster_id = schema_builder.add_text_field("token", tok_opts);
75 let doc_id = schema_builder.add_u64_field("doc_id", FAST | STORED);
76 let schema = schema_builder.build();
77 Ok((schema, token_cluster_id, doc_id))
78 }
79
80 pub fn index(&mut self, doc_id: u64, tokens: Vec<String>, scores: Vec<f32>) {
81 debug_assert_eq!(tokens.len(), scores.len());
82 let doc_entry = self.pending.entry(doc_id).or_default();
83 for (token, score) in tokens.into_iter().zip(scores.into_iter()) {
84 doc_entry.push((token, score));
85 }
86 }
87
88 pub fn finalize(&mut self, filter_stop_words: bool) -> Result<()> {
90 let stop_words = if filter_stop_words {
91 self.prepare_stop_words()
92 } else {
93 HashSet::new()
94 };
95 for (&doc_id, token_score_pairs) in self.pending.iter() {
96 let mut doc = TantivyDocument::new();
97 doc.add_u64(self.doc_id, doc_id);
98 let mut set = false;
99 for (token, score) in token_score_pairs {
100 if stop_words.contains(token) {
101 continue;
102 }
103 if *score < 22.7136 {
105 continue;
106 }
107 let reps = F8E4M3::from_f32(*score).to_bits();
109 if reps == 0 {
110 continue;
111 }
112 set = true;
113 for _ in 0..reps {
115 doc.add_text(self.token_cluster_id, token);
116 }
117 }
118 if set {
119 self.writer.add_document(doc)?;
120 } else {
121 panic!("adjust hyperparams, no tokens were added to doc")
122 }
123 }
124 self.pending.clear();
125 self.writer.commit()?;
126 self.writer
127 .merge(&self.index.searchable_segment_ids()?)
128 .wait()?;
129 self.reader.reload()?;
130 Ok(())
131 }
132
133 fn prepare_stop_words(&self) -> HashSet<&String> {
134 let mut token_to_doc_count = HashMap::new();
135 for (_, token_score_pairs) in self.pending.iter() {
136 let mut seen = HashSet::new();
137 for (token, _) in token_score_pairs {
138 if seen.insert(token) {
139 *token_to_doc_count.entry(token).or_insert(0) += 1;
140 }
141 }
142 }
143 let total_docs = self.pending.len() as f32;
144 token_to_doc_count
145 .into_iter()
146 .filter(|(_, doc_count)| (*doc_count as f32 / total_docs) >= MAX_DF_RATIO)
147 .map(|(token, _)| token)
148 .collect()
149 }
150
151 fn build_reader(index: &Index) -> Result<IndexReader> {
152 let reader = index
153 .reader_builder()
154 .reload_policy(ReloadPolicy::Manual)
155 .try_into()?;
156 Ok(reader)
157 }
158
159 pub fn get_num_docs(&self) -> u64 {
160 let searcher = self.reader.searcher();
161 searcher.num_docs()
162 }
163
164 pub fn search(
168 &self,
169 searcher: Option<&Searcher>,
170 pairs: &[&str],
171 top_k: usize,
172 ) -> Result<Vec<(u64, f64)>> {
173 if pairs.is_empty() {
174 return Ok(Vec::new());
175 }
176 let searcher = if let Some(searcher) = searcher {
177 searcher
178 } else {
179 &self.reader.searcher()
180 };
181 let mut clauses = Vec::with_capacity(pairs.len());
182 for &tok in pairs {
183 let term = Term::from_field_text(self.token_cluster_id, tok);
184 clauses.push(Box::new(TfTermQuery::new(term)) as Box<dyn Query>);
185 }
186 let bool_q = BooleanQuery::union(clauses);
187
188 let hits = searcher.search(&bool_q, &tantivy::collector::TopDocs::with_limit(top_k))?;
189
190 let mut results = Vec::with_capacity(hits.len());
191 for (score, doc_addr) in hits {
192 let retrieved_doc: TantivyDocument = searcher.doc(doc_addr)?;
193 let doc_id: u64 = retrieved_doc
194 .get_first(self.doc_id)
195 .unwrap()
196 .as_u64()
197 .unwrap();
198 results.push((doc_id, score as f64));
199 }
200 Ok(results)
201 }
202}