1use std::collections::{HashMap, HashSet};
2
3use super::{
4 pipeline::{DefaultTokenizer, Pipeline},
5 tokenizer::Token,
6 types::{
7 DocData, DomainLengths, InMemoryIndex, PipelineToken, PositionEncoding, SNAPSHOT_VERSION,
8 SnapshotData, TermDomain, TokenStream,
9 },
10};
11
/// One changed document awaiting flush: (index_name, doc_id, content, doc_len).
type DirtyDoc = (String, String, String, i64);
/// Deleted doc ids awaiting flush, keyed by index name.
type DeletedDoc = HashMap<String, HashSet<String>>;
14
impl InMemoryIndex {
    /// Creates an empty index whose match spans are reported using `encoding`
    /// (byte offsets or UTF-16 code units); see the `get_matches*` methods.
    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
        let mut index = Self::default();
        index.position_encoding = encoding;
        index
    }

    /// Creates an empty index that tokenizes documents and queries with the
    /// given dictionary configuration instead of the default pipeline.
    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
        let mut index = Self::default();
        index.dictionary = Some(dictionary);
        index
    }

    /// Changes the span encoding used by future `get_matches*` calls. Nothing
    /// stored is re-encoded; spans are converted lazily at query time.
    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
        self.position_encoding = encoding;
    }

    /// Replaces (or, with `None`, clears) the dictionary configuration used to
    /// build tokenization pipelines. Affects only documents added afterwards.
    pub fn set_dictionary_config(
        &mut self,
        dictionary: Option<crate::tokenizer::DictionaryConfig>,
    ) {
        self.dictionary = dictionary;
    }

    /// Adds (or replaces) document `doc_id` in `index_name`.
    ///
    /// When `index` is `false` the content is stored but not tokenized, so the
    /// document is retrievable via `get_doc` without being searchable. Any
    /// previous version of the document is first unregistered (length totals,
    /// per-domain totals, postings). The document ends up in the dirty set and
    /// is removed from the deleted set.
    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
        let token_stream = if index {
            self.document_pipeline().document_tokens(text)
        } else {
            // Empty stream: nothing gets indexed for this document.
            TokenStream {
                tokens: Vec::new(),
                term_freqs: HashMap::new(),
                doc_len: 0,
            }
        };

        // Split token spans by domain: Original-domain positions are kept in
        // token order (duplicates allowed); derived-term spans are collected
        // into a set, deduplicating repeats.
        let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
        let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
        for PipelineToken {
            term, span, domain, ..
        } in &token_stream.tokens
        {
            if *domain == TermDomain::Original {
                pos_map
                    .entry(term.clone())
                    .or_default()
                    .push((span.0 as u32, span.1 as u32));
            } else {
                derived_mapping
                    .entry(term.clone())
                    .or_default()
                    .insert((span.0 as u32, span.1 as u32));
            }
        }
        let doc_len = token_stream.doc_len;
        let term_freqs = token_stream.term_freqs;
        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
        if domain_doc_len.is_zero() {
            // Fallback: attribute the whole document length to the Original
            // domain when the term frequencies carry no per-domain lengths.
            domain_doc_len.add(TermDomain::Original, doc_len);
        }

        // Unregister the previous version of this document, if any: subtract
        // its lengths from the totals and drop its postings.
        if let Some(docs) = self.docs.get_mut(index_name) {
            if let Some(old_data) = docs.remove(doc_id) {
                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;

                let old_domain_lengths = DomainLengths::from_doc(&old_data);
                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
                    old_domain_lengths.for_each_nonzero(|domain, len| {
                        // Negative delta removes the old contribution.
                        total_by_domain.add(domain, -len);
                    });
                }

                self.index_maps_mut(index_name)
                    .remove_doc_terms(doc_id, &old_data);
            }
        }

        // Register the new postings.
        let mut writer = self.index_writer(index_name, doc_id);
        for (term, freqs) in &term_freqs {
            writer.add_term_frequency(term, freqs);
        }

        let doc_data = DocData {
            content: text.to_string(),
            doc_len,
            term_pos: pos_map,
            term_freqs,
            domain_doc_len: domain_doc_len.clone(),
            derived_terms: derived_mapping
                .into_iter()
                .map(|(k, v)| {
                    // Sort/dedup the spans, then keep only the shortest ones:
                    // a derived term maps to its tightest source span(s).
                    let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
                    spans.sort();
                    spans.dedup();
                    if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
                        spans.retain(|(s, e)| e - s == min_len);
                    }
                    (k, spans)
                })
                .collect(),
        };

        // Store the document and add its lengths to the per-index totals.
        self.docs
            .entry(index_name.to_string())
            .or_default()
            .insert(doc_id.to_string(), doc_data);
        *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
        let total_by_domain = self
            .domain_total_lens
            .entry(index_name.to_string())
            .or_default();
        domain_doc_len.for_each_nonzero(|domain, len| {
            total_by_domain.add(domain, len);
        });

        // The document now needs flushing and is no longer pending deletion.
        self.dirty
            .entry(index_name.to_string())
            .or_default()
            .insert(doc_id.to_string());
        if let Some(deleted) = self.deleted.get_mut(index_name) {
            deleted.remove(doc_id);
        }
    }

    /// Removes `doc_id` from `index_name`, undoing its contribution to the
    /// length totals and postings. The id is recorded in the deleted set and
    /// dropped from the dirty set. Unknown index or doc id is a no-op.
    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
        if let Some(docs) = self.docs.get_mut(index_name) {
            if let Some(old_data) = docs.remove(doc_id) {
                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;

                let old_domain_lengths = DomainLengths::from_doc(&old_data);
                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
                    old_domain_lengths.for_each_nonzero(|domain, len| {
                        // Negative delta removes this document's contribution.
                        total_by_domain.add(domain, -len);
                    });
                }

                self.index_maps_mut(index_name)
                    .remove_doc_terms(doc_id, &old_data);

                // Record the deletion for the next flush; a doc cannot be
                // both dirty and deleted at the same time.
                self.deleted
                    .entry(index_name.to_string())
                    .or_default()
                    .insert(doc_id.to_string());
                if let Some(dirty) = self.dirty.get_mut(index_name) {
                    dirty.remove(doc_id);
                }
            }
        }
    }

    /// Returns a clone of the stored content of `doc_id`, if present.
    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
        self.docs
            .get(index_name)
            .and_then(|docs| docs.get(doc_id))
            .map(|d| d.content.clone())
    }

    /// Drains the dirty/deleted bookkeeping, returning the changed documents
    /// (with current content and length) and the per-index deleted id sets.
    /// Dirty ids whose document no longer exists are silently skipped.
    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
        let dirty = std::mem::take(&mut self.dirty);
        let deleted = std::mem::take(&mut self.deleted);

        let mut dirty_data = Vec::new();
        for (index_name, doc_ids) in &dirty {
            if let Some(docs) = self.docs.get(index_name) {
                for doc_id in doc_ids {
                    if let Some(data) = docs.get(doc_id) {
                        dirty_data.push((
                            index_name.clone(),
                            doc_id.clone(),
                            data.content.clone(),
                            data.doc_len,
                        ));
                    }
                }
            }
        }
        (dirty_data, deleted)
    }

    /// Tokenizes `query` with the index's query pipeline and returns the
    /// spans at which the resulting terms occur in `doc_id`.
    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
        let query_terms: Vec<String> = self
            .tokenize_query(query)
            .into_iter()
            .map(|t| t.term)
            .collect();
        self.get_matches_for_terms(index_name, doc_id, &query_terms)
    }

    /// Returns the spans at which `terms` occur in `doc_id`, converted to the
    /// configured position encoding and sorted by start offset. For each term
    /// the original-token positions win; derived-term spans are used only when
    /// the term has no original occurrences.
    pub fn get_matches_for_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[String],
    ) -> Vec<(u32, u32)> {
        let mut matches = Vec::new();
        if let Some(docs) = self.docs.get(index_name) {
            if let Some(doc_data) = docs.get(doc_id) {
                for term in terms {
                    if let Some(positions) = doc_data.term_pos.get(term) {
                        matches.extend(positions.iter().cloned());
                        // Original positions take precedence over derived spans.
                        continue;
                    }
                    if let Some(positions) = doc_data.derived_terms.get(term) {
                        matches.extend(positions.iter().cloned());
                    }
                }
                if !matches.is_empty() {
                    // Stored spans are byte offsets; remap if UTF-16 requested.
                    matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
                }
            }
        }
        matches.sort_by(|a, b| a.0.cmp(&b.0));
        matches
    }

    /// Convenience wrapper over [`Self::get_matches_for_terms`] that extracts
    /// the term strings from `MatchedTerm` values.
    pub fn get_matches_for_matched_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[crate::types::MatchedTerm],
    ) -> Vec<(u32, u32)> {
        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
        self.get_matches_for_terms(index_name, doc_id, &term_strings)
    }

    /// Replaces the in-memory maps of `index_name` with `snapshot` and records
    /// the resulting version.
    ///
    /// # Panics
    /// Panics if `snapshot.version` differs from `SNAPSHOT_VERSION`.
    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
        assert_eq!(
            snapshot.version, SNAPSHOT_VERSION,
            "snapshot version {} does not match expected {}",
            snapshot.version, SNAPSHOT_VERSION
        );
        let version = {
            let mut maps = self.index_maps_mut(index_name);
            // NOTE(review): clear(false) presumably preserves some sub-state
            // while wiping the rest — confirm against index_maps_mut's impl.
            maps.clear(false);
            maps.import_snapshot(snapshot);
            maps.version
        };
        self.versions.insert(index_name.to_string(), version);
    }

    /// Builds a snapshot of `index_name` (cloned docs, totals and domains),
    /// stamped with the index's current version (defaulting to
    /// `SNAPSHOT_VERSION`). Returns `None` for an unknown index.
    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
        self.docs.get(index_name).map(|docs| {
            let domains = self.domains.get(index_name).cloned().unwrap_or_default();

            SnapshotData {
                version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
                docs: docs.clone(),
                total_len: *self.total_lens.get(index_name).unwrap_or(&0),
                domain_total_len: self
                    .domain_total_lens
                    .get(index_name)
                    .cloned()
                    .unwrap_or_default(),
                domains,
            }
        })
    }

    /// Document-indexing pipeline, dictionary-aware when one is configured.
    fn document_pipeline(&self) -> Pipeline {
        if let Some(cfg) = &self.dictionary {
            Pipeline::with_dictionary(cfg.clone())
        } else {
            Pipeline::document_pipeline()
        }
    }

    /// Tokenizes a query string, using the configured dictionary when set,
    /// and converts pipeline tokens to plain `Token`s (term + byte span).
    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
        if let Some(cfg) = &self.dictionary {
            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
                .query_tokens(query)
                .tokens
                .into_iter()
                .map(|token| Token {
                    term: token.term,
                    start: token.span.0,
                    end: token.span.1,
                })
                .collect()
        } else {
            Pipeline::tokenize_query(query)
        }
    }
}
311
312fn convert_spans(
313 content: &str,
314 spans: &[(u32, u32)],
315 encoding: PositionEncoding,
316) -> Vec<(u32, u32)> {
317 match encoding {
318 PositionEncoding::Bytes => spans.to_vec(),
319 PositionEncoding::Utf16 => spans
320 .iter()
321 .map(|(start, end)| {
322 let s = to_utf16_index(content, *start as usize);
323 let e = to_utf16_index(content, *end as usize);
324 (s as u32, e as u32)
325 })
326 .collect(),
327 }
328}
329
/// Converts a byte offset into `content` to the equivalent UTF-16 code-unit
/// index (the number of UTF-16 units in the prefix before `byte_idx`).
///
/// Offsets past the end of `content` are clamped to its length. Offsets that
/// land inside a multi-byte UTF-8 sequence are clamped back to the nearest
/// preceding character boundary instead of panicking (the original
/// `&content[..idx]` slice would panic on a non-boundary index).
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    if byte_idx == 0 {
        return 0;
    }
    let mut end = byte_idx.min(content.len());
    // Walk back to a char boundary so the slice below cannot panic.
    while end > 0 && !content.is_char_boundary(end) {
        end -= 1;
    }
    content[..end].encode_utf16().count()
}