1use std::collections::{HashMap, HashSet};
2
3use super::{
4 pipeline::{DefaultTokenizer, Pipeline},
5 tokenizer::Token,
6 types::{
7 DocData, DomainLengths, InMemoryIndex, PipelineToken, PositionEncoding, SNAPSHOT_VERSION,
8 SnapshotData, TermDomain, TokenStream,
9 },
10};
11
/// (index_name, doc_id, content, doc_len) for a document awaiting persistence.
type DirtyDoc = (String, String, String, i64);
/// index_name -> doc_ids removed since the last persist.
type DeletedDoc = HashMap<String, HashSet<String>>;
14
15impl InMemoryIndex {
16 pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
18 let mut index = Self::default();
19 index.position_encoding = encoding;
20 index
21 }
22
23 pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
25 let mut index = Self::default();
26 index.dictionary = Some(dictionary);
27 index
28 }
29
30 pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
32 self.position_encoding = encoding;
33 }
34
35 pub fn set_dictionary_config(
37 &mut self,
38 dictionary: Option<crate::tokenizer::DictionaryConfig>,
39 ) {
40 self.dictionary = dictionary;
41 }
42
43 pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
46 let token_stream = if index {
47 self.document_pipeline().document_tokens(text)
48 } else {
49 TokenStream {
50 tokens: Vec::new(),
51 term_freqs: HashMap::new(),
52 doc_len: 0,
53 }
54 };
55
56 let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
57 let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
58 for PipelineToken {
59 term, span, domain, ..
60 } in &token_stream.tokens
61 {
62 if *domain == TermDomain::Original {
63 pos_map
64 .entry(term.clone())
65 .or_default()
66 .push((span.0 as u32, span.1 as u32));
67 } else {
68 derived_mapping
69 .entry(term.clone())
70 .or_default()
71 .insert((span.0 as u32, span.1 as u32));
72 }
73 }
74 let doc_len = token_stream.doc_len;
75 let term_freqs = token_stream.term_freqs;
76 let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
77 if domain_doc_len.is_zero() {
78 domain_doc_len.add(TermDomain::Original, doc_len);
79 }
80
81 if let Some(docs) = self.docs.get_mut(index_name) {
82 if let Some(old_data) = docs.remove(doc_id) {
83 *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
84
85 let old_domain_lengths = DomainLengths::from_doc(&old_data);
86 if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
87 old_domain_lengths.for_each_nonzero(|domain, len| {
88 total_by_domain.add(domain, -len);
89 });
90 }
91
92 self.index_maps_mut(index_name)
93 .remove_doc_terms(doc_id, &old_data);
94 }
95 }
96
97 let mut writer = self.index_writer(index_name, doc_id);
98 for (term, freqs) in &term_freqs {
99 writer.add_term_frequency(term, freqs);
100 }
101
102 let doc_data = DocData {
103 content: text.to_string(),
104 doc_len,
105 term_pos: pos_map,
106 term_freqs,
107 domain_doc_len: domain_doc_len.clone(),
108 derived_terms: derived_mapping
109 .into_iter()
110 .map(|(k, v)| {
111 let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
112 spans.sort();
113 spans.dedup();
114 if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
115 spans.retain(|(s, e)| e - s == min_len);
116 }
117 (k, spans)
118 })
119 .collect(),
120 };
121
122 self.docs
123 .entry(index_name.to_string())
124 .or_default()
125 .insert(doc_id.to_string(), doc_data);
126 *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
127 let total_by_domain = self
128 .domain_total_lens
129 .entry(index_name.to_string())
130 .or_default();
131 domain_doc_len.for_each_nonzero(|domain, len| {
132 total_by_domain.add(domain, len);
133 });
134
135 self.dirty
136 .entry(index_name.to_string())
137 .or_default()
138 .insert(doc_id.to_string());
139 if let Some(deleted) = self.deleted.get_mut(index_name) {
140 deleted.remove(doc_id);
141 }
142 }
143
144 pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
146 if let Some(docs) = self.docs.get_mut(index_name) {
147 if let Some(old_data) = docs.remove(doc_id) {
148 *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
149
150 let old_domain_lengths = DomainLengths::from_doc(&old_data);
151 if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
152 old_domain_lengths.for_each_nonzero(|domain, len| {
153 total_by_domain.add(domain, -len);
154 });
155 }
156
157 self.index_maps_mut(index_name)
158 .remove_doc_terms(doc_id, &old_data);
159
160 self.deleted
161 .entry(index_name.to_string())
162 .or_default()
163 .insert(doc_id.to_string());
164 if let Some(dirty) = self.dirty.get_mut(index_name) {
165 dirty.remove(doc_id);
166 }
167 }
168 }
169 }
170
171 pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
173 self.docs
174 .get(index_name)
175 .and_then(|docs| docs.get(doc_id))
176 .map(|d| d.content.clone())
177 }
178
179 pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
181 let dirty = std::mem::take(&mut self.dirty);
182 let deleted = std::mem::take(&mut self.deleted);
183
184 let mut dirty_data = Vec::new();
185 for (index_name, doc_ids) in &dirty {
186 if let Some(docs) = self.docs.get(index_name) {
187 for doc_id in doc_ids {
188 if let Some(data) = docs.get(doc_id) {
189 dirty_data.push((
190 index_name.clone(),
191 doc_id.clone(),
192 data.content.clone(),
193 data.doc_len,
194 ));
195 }
196 }
197 }
198 }
199 (dirty_data, deleted)
200 }
201
202 pub fn has_unpersisted_changes(&self, index_name: Option<&str>) -> bool {
205 match index_name {
206 Some(name) => {
207 self.dirty.get(name).map_or(false, |s| !s.is_empty())
208 || self.deleted.get(name).map_or(false, |s| !s.is_empty())
209 }
210 None => {
211 self.dirty.values().any(|s| !s.is_empty())
212 || self.deleted.values().any(|s| !s.is_empty())
213 }
214 }
215 }
216
217 pub fn persist_if_dirty<E>(
222 &mut self,
223 index_name: &str,
224 mut persist: impl FnMut(SnapshotData) -> Result<(), E>,
225 ) -> Result<bool, E> {
226 if !self.has_unpersisted_changes(Some(index_name)) {
227 return Ok(false);
228 }
229
230 let Some(snapshot) = self.get_snapshot_data(index_name) else {
231 return Ok(false);
232 };
233
234 persist(snapshot)?;
235 self.dirty.remove(index_name);
236 self.deleted.remove(index_name);
237 Ok(true)
238 }
239
240 pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
242 let query_terms: Vec<String> = self
243 .tokenize_query(query)
244 .into_iter()
245 .map(|t| t.term)
246 .collect();
247 self.get_matches_for_terms(index_name, doc_id, &query_terms)
248 }
249
250 pub fn get_matches_for_terms(
252 &self,
253 index_name: &str,
254 doc_id: &str,
255 terms: &[String],
256 ) -> Vec<(u32, u32)> {
257 let mut matches = Vec::new();
258 if let Some(docs) = self.docs.get(index_name) {
259 if let Some(doc_data) = docs.get(doc_id) {
260 for term in terms {
261 if let Some(positions) = doc_data.term_pos.get(term) {
262 matches.extend(positions.iter().cloned());
263 continue;
264 }
265 if let Some(positions) = doc_data.derived_terms.get(term) {
266 matches.extend(positions.iter().cloned());
267 }
268 }
269 if !matches.is_empty() {
270 matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
271 }
272 }
273 }
274 matches.sort_by(|a, b| a.0.cmp(&b.0));
275 matches
276 }
277
278 pub fn get_matches_for_matched_terms(
280 &self,
281 index_name: &str,
282 doc_id: &str,
283 terms: &[crate::types::MatchedTerm],
284 ) -> Vec<(u32, u32)> {
285 let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
286 self.get_matches_for_terms(index_name, doc_id, &term_strings)
287 }
288
289 pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
291 assert_eq!(
292 snapshot.version, SNAPSHOT_VERSION,
293 "snapshot version {} does not match expected {}",
294 snapshot.version, SNAPSHOT_VERSION
295 );
296 let version = {
297 let mut maps = self.index_maps_mut(index_name);
298 maps.clear(false);
299 maps.import_snapshot(snapshot);
300 maps.version
301 };
302 self.versions.insert(index_name.to_string(), version);
303 self.dirty.remove(index_name);
304 self.deleted.remove(index_name);
305 }
306
307 pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
309 self.docs.get(index_name).map(|docs| {
310 let domains = self.domains.get(index_name).cloned().unwrap_or_default();
311
312 SnapshotData {
313 version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
314 docs: docs.clone(),
315 total_len: *self.total_lens.get(index_name).unwrap_or(&0),
316 domain_total_len: self
317 .domain_total_lens
318 .get(index_name)
319 .cloned()
320 .unwrap_or_default(),
321 domains,
322 }
323 })
324 }
325
326 fn document_pipeline(&self) -> Pipeline {
327 if let Some(cfg) = &self.dictionary {
328 Pipeline::with_dictionary(cfg.clone())
329 } else {
330 Pipeline::document_pipeline()
331 }
332 }
333
334 pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
335 if let Some(cfg) = &self.dictionary {
336 Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
337 .query_tokens(query)
338 .tokens
339 .into_iter()
340 .map(|token| Token {
341 term: token.term,
342 start: token.span.0,
343 end: token.span.1,
344 })
345 .collect()
346 } else {
347 Pipeline::tokenize_query(query)
348 }
349 }
350}
351
352fn convert_spans(
353 content: &str,
354 spans: &[(u32, u32)],
355 encoding: PositionEncoding,
356) -> Vec<(u32, u32)> {
357 match encoding {
358 PositionEncoding::Bytes => spans.to_vec(),
359 PositionEncoding::Utf16 => spans
360 .iter()
361 .map(|(start, end)| {
362 let s = to_utf16_index(content, *start as usize);
363 let e = to_utf16_index(content, *end as usize);
364 (s as u32, e as u32)
365 })
366 .collect(),
367 }
368}
369
/// Converts a byte offset in `content` to the equivalent UTF-16 code-unit
/// offset.
///
/// Offsets past the end of `content` are clamped to its length, and offsets
/// that land inside a multi-byte UTF-8 sequence are snapped back to the
/// previous character boundary instead of panicking on the slice.
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    let mut idx = byte_idx.min(content.len());
    // `is_char_boundary(0)` is always true, so this loop terminates.
    while !content.is_char_boundary(idx) {
        idx -= 1;
    }
    content[..idx].encode_utf16().count()
}