1use std::collections::{HashMap, HashSet};
2
3use smol_str::SmolStr;
4
5use super::{
6 SNAPSHOT_VERSION,
7 index::Index,
8 pipeline::{DefaultTokenizer, Pipeline},
9 tokenizer::Token,
10 types::{
11 DerivedSpan, DerivedTerm, DocData, DomainLengths, InMemoryIndex, PositionEncoding,
12 SnapshotData, TermDomain, TermFrequencyEntry, TermId, TermPositions, TokenStream,
13 },
14};
15
16type DirtyDoc = (String, String, String, i64);
17type DeletedDoc = HashMap<String, HashSet<String>>;
18
impl InMemoryIndex {
    /// Builds an index whose reported match spans use `encoding`; all other
    /// fields take their default values.
    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
        Self {
            position_encoding: encoding,
            ..Default::default()
        }
    }

    /// Builds an index that tokenizes documents with the given dictionary
    /// configuration; all other fields take their default values.
    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
        Self {
            dictionary: Some(dictionary),
            ..Default::default()
        }
    }

    /// Switches the encoding used for match spans returned by the
    /// `get_matches*` methods. Affects only future calls, not stored data.
    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
        self.position_encoding = encoding;
    }

    /// Replaces (or clears, with `None`) the tokenizer dictionary config.
    /// Documents already indexed are not re-tokenized.
    pub fn set_dictionary_config(
        &mut self,
        dictionary: Option<crate::tokenizer::DictionaryConfig>,
    ) {
        self.dictionary = dictionary;
    }

    /// Adds or replaces the document `doc_id` in `index_name`.
    ///
    /// When `index` is `false` the text is stored (retrievable via
    /// `get_doc`) but produces an empty token stream, so it contributes no
    /// postings. The doc is marked dirty and removed from the deleted set.
    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
        let token_stream = if index {
            self.document_pipeline().document_tokens(text)
        } else {
            TokenStream {
                tokens: Vec::new(),
                doc_len: 0,
            }
        };

        let mut maps = Index {
            state: self.index_state_mut(index_name),
        };

        // Resolve a numeric doc slot: keep the existing one on re-add,
        // otherwise recycle a freed slot, otherwise append a new one.
        let doc_idx = if let Some(existing) = maps.state.doc_index.get(doc_id) {
            *existing
        } else if let Some(reuse) = maps.state.free_docs.pop() {
            let doc_key = SmolStr::new(doc_id);
            if let Some(slot) = maps.state.doc_ids.get_mut(reuse as usize) {
                *slot = doc_key.clone();
            } else {
                // Recycled id beyond the current table: grow it first.
                maps.state
                    .doc_ids
                    .resize(reuse as usize + 1, SmolStr::default());
                maps.state.doc_ids[reuse as usize] = doc_key.clone();
            }
            if maps.state.docs.len() <= reuse as usize {
                maps.state.docs.resize(reuse as usize + 1, None);
            }
            maps.state.doc_index.insert(doc_key, reuse);
            reuse
        } else {
            let doc_key = SmolStr::new(doc_id);
            let id = maps.state.doc_ids.len() as super::types::DocId;
            maps.state.doc_ids.push(doc_key.clone());
            maps.state.docs.push(None);
            maps.state.doc_index.insert(doc_key, id);
            id
        };

        // If the slot already held data (re-add), back its lengths and
        // postings out of the aggregate tables before indexing anew.
        if let Some(old_data) = maps
            .state
            .docs
            .get_mut(doc_idx as usize)
            .and_then(|slot| slot.take())
        {
            maps.state.total_len -= old_data.doc_len;
            let old_domain_lengths = DomainLengths::from_doc(&old_data);
            old_domain_lengths.for_each_nonzero(|domain, len| {
                maps.state.domain_total_len.add(domain, -len);
            });
            maps.remove_doc_terms(doc_idx, &old_data);
        }

        let mut term_pos: HashMap<TermId, Vec<(u32, u32)>> = HashMap::new();
        let mut derived_candidates: Vec<(TermId, TermId, (u32, u32))> = Vec::new();
        let mut term_freqs: HashMap<TermId, [u32; super::types::TERM_DOMAIN_COUNT]> =
            HashMap::new();

        for token in &token_stream.tokens {
            let term_id = get_or_insert_term_id(maps.state, &token.term);
            let domain_idx = super::types::domain_index(token.domain);
            let counts = term_freqs
                .entry(term_id)
                .or_insert([0; super::types::TERM_DOMAIN_COUNT]);
            counts[domain_idx] += 1;

            if token.domain == TermDomain::Original {
                // Literal occurrences keep their byte spans for highlighting.
                term_pos
                    .entry(term_id)
                    .or_default()
                    .push((token.span.0 as u32, token.span.1 as u32));
            } else {
                // Non-original (derived) tokens are resolved after the scan,
                // once we know which base terms occur literally.
                let base_term_id = get_or_insert_term_id(maps.state, &token.base_term);
                derived_candidates.push((
                    term_id,
                    base_term_id,
                    (token.span.0 as u32, token.span.1 as u32),
                ));
            }
        }

        // Sorted by term id so lookups can binary-search (find_term_positions).
        let mut term_positions: Vec<TermPositions> = term_pos
            .into_iter()
            .map(|(term, mut positions)| {
                positions.sort();
                positions.dedup();
                TermPositions { term, positions }
            })
            .collect();
        term_positions.sort_by_key(|entry| entry.term);

        // Derived candidates whose base term occurs literally become
        // derived->base links; otherwise we remember the shortest span seen
        // directly under the derived term.
        let base_terms: HashSet<TermId> = term_positions.iter().map(|entry| entry.term).collect();
        let mut derived_terms: Vec<DerivedTerm> = Vec::new();
        let mut derived_spans_map: HashMap<TermId, (u32, u32)> = HashMap::new();
        for (derived, base, span) in derived_candidates {
            if base_terms.contains(&base) {
                derived_terms.push(DerivedTerm { derived, base });
            } else {
                let span_len = span.1.saturating_sub(span.0);
                derived_spans_map
                    .entry(derived)
                    .and_modify(|existing| {
                        let existing_len = existing.1.saturating_sub(existing.0);
                        if span_len < existing_len {
                            *existing = span;
                        }
                    })
                    .or_insert(span);
            }
        }
        // Sorted (and deduped) so lookups can binary-search (find_base_terms,
        // find_derived_spans).
        derived_terms.sort_by(|a, b| (a.derived, a.base).cmp(&(b.derived, b.base)));
        derived_terms.dedup_by(|a, b| a.derived == b.derived && a.base == b.base);
        let mut derived_spans: Vec<DerivedSpan> = derived_spans_map
            .into_iter()
            .map(|(derived, span)| DerivedSpan { derived, span })
            .collect();
        derived_spans.sort_by_key(|entry| entry.derived);

        let mut term_freqs_vec: Vec<TermFrequencyEntry> = term_freqs
            .into_iter()
            .map(|(term, counts)| TermFrequencyEntry { term, counts })
            .collect();
        term_freqs_vec.sort_by_key(|entry| entry.term);

        let doc_len = token_stream.doc_len;
        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs_vec);
        if domain_doc_len.is_zero() {
            // No per-domain counts at all: attribute the whole length to the
            // Original domain so totals still balance.
            domain_doc_len.add(TermDomain::Original, doc_len);
        }

        for entry in &term_freqs_vec {
            for (domain, count) in entry.positive_domains() {
                maps.add_posting(entry.term, domain, doc_idx, count);
            }
        }

        let doc_data = DocData {
            content: text.to_string(),
            doc_len,
            term_pos: term_positions,
            term_freqs: term_freqs_vec,
            domain_doc_len,
            derived_terms,
            derived_spans,
        };

        if maps.state.docs.len() <= doc_idx as usize {
            maps.state.docs.resize(doc_idx as usize + 1, None);
        }
        maps.state.docs[doc_idx as usize] = Some(doc_data);

        maps.state.total_len += doc_len;
        domain_doc_len.for_each_nonzero(|domain, len| {
            maps.state.domain_total_len.add(domain, len);
        });

        // Mark dirty for the next persist; a re-added doc is no longer deleted.
        let doc_key = maps
            .state
            .doc_ids
            .get(doc_idx as usize)
            .cloned()
            .unwrap_or_else(|| SmolStr::new(doc_id));
        maps.state.dirty.insert(doc_key.clone());
        maps.state.deleted.remove(doc_key.as_str());
    }

    /// Removes `doc_id` from `index_name`, if present.
    ///
    /// Backs its lengths and postings out of the aggregate tables, recycles
    /// the numeric slot via `free_docs`, and marks the id deleted (not dirty).
    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
        let mut maps = Index {
            state: self.index_state_mut(index_name),
        };
        let Some(&doc_idx) = maps.state.doc_index.get(doc_id) else {
            return;
        };

        if let Some(old_data) = maps
            .state
            .docs
            .get_mut(doc_idx as usize)
            .and_then(|slot| slot.take())
        {
            maps.state.total_len -= old_data.doc_len;
            let old_domain_lengths = DomainLengths::from_doc(&old_data);
            old_domain_lengths.for_each_nonzero(|domain, len| {
                maps.state.domain_total_len.add(domain, -len);
            });
            maps.remove_doc_terms(doc_idx, &old_data);
        }

        maps.state.doc_index.remove(doc_id);
        // Numeric slot becomes available for the next add_doc.
        maps.state.free_docs.push(doc_idx);
        let doc_key = maps
            .state
            .doc_ids
            .get(doc_idx as usize)
            .cloned()
            .unwrap_or_else(|| SmolStr::new(doc_id));
        maps.state.deleted.insert(doc_key);
        maps.state.dirty.remove(doc_id);
    }

    /// Returns a copy of the stored content for `doc_id`, if the document
    /// exists in `index_name`.
    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
        let state = self.indexes.get(index_name)?;
        let doc_idx = *state.doc_index.get(doc_id)? as usize;
        state
            .docs
            .get(doc_idx)
            .and_then(|doc| doc.as_ref())
            .map(|d| d.content.clone())
    }

    /// Drains the dirty/deleted markers from every index and returns them:
    /// dirty docs as `(index_name, doc_id, content, doc_len)` tuples, and
    /// deleted ids grouped per index. The markers are cleared as a side
    /// effect, so the caller owns the responsibility of persisting them.
    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
        let mut dirty_data = Vec::new();
        let mut deleted = HashMap::new();

        for (index_name, state) in self.indexes.iter_mut() {
            let dirty = std::mem::take(&mut state.dirty);
            let deleted_ids = std::mem::take(&mut state.deleted);

            for doc_id in dirty {
                // A dirty id whose doc has since been removed is skipped.
                if let Some(&doc_idx) = state.doc_index.get(&doc_id)
                    && let Some(doc) = state
                        .docs
                        .get(doc_idx as usize)
                        .and_then(|entry| entry.as_ref())
                {
                    dirty_data.push((
                        index_name.clone(),
                        doc_id.to_string(),
                        doc.content.clone(),
                        doc.doc_len,
                    ));
                }
            }

            if !deleted_ids.is_empty() {
                let deleted_strings: HashSet<String> = deleted_ids
                    .into_iter()
                    .map(|doc_id| doc_id.to_string())
                    .collect();
                deleted.insert(index_name.clone(), deleted_strings);
            }
        }

        (dirty_data, deleted)
    }

    /// Reports whether the named index (or, with `None`, any index) has
    /// changes that have not been persisted yet.
    pub fn has_unpersisted_changes(&self, index_name: Option<&str>) -> bool {
        match index_name {
            Some(name) => self
                .indexes
                .get(name)
                .is_some_and(|state| !state.dirty.is_empty() || !state.deleted.is_empty()),
            None => self
                .indexes
                .values()
                .any(|state| !state.dirty.is_empty() || !state.deleted.is_empty()),
        }
    }

    /// Snapshots `index_name` and hands it to `persist` only when there are
    /// unpersisted changes and a snapshot is available.
    ///
    /// Returns `Ok(true)` when `persist` ran successfully; the dirty/deleted
    /// markers are cleared only after it succeeds, so a failed persist leaves
    /// the index still flagged for retry.
    pub fn persist_if_dirty<E>(
        &mut self,
        index_name: &str,
        mut persist: impl FnMut(SnapshotData) -> Result<(), E>,
    ) -> Result<bool, E> {
        if !self.has_unpersisted_changes(Some(index_name)) {
            return Ok(false);
        }

        let Some(snapshot) = self.get_snapshot_data(index_name) else {
            return Ok(false);
        };

        persist(snapshot)?;
        if let Some(state) = self.indexes.get_mut(index_name) {
            state.dirty.clear();
            state.deleted.clear();
        }
        Ok(true)
    }

    /// Tokenizes `query` and returns the match spans of its terms within
    /// `doc_id`, in the configured position encoding.
    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
        let query_terms: Vec<String> = self
            .tokenize_query(query)
            .into_iter()
            .map(|t| t.term)
            .collect();
        self.get_matches_for_terms(index_name, doc_id, &query_terms)
    }

    /// Returns the match spans for already-tokenized `terms` within `doc_id`.
    ///
    /// Per term, the lookup falls back in three stages: literal positions of
    /// the term itself, then positions of any base terms it derives from,
    /// then recorded derived spans. Results are converted to the configured
    /// encoding, sorted by start (shorter spans first), and pruned so each
    /// start offset keeps only its shortest span.
    pub fn get_matches_for_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[String],
    ) -> Vec<(u32, u32)> {
        let mut matches = Vec::new();
        let Some(state) = self.indexes.get(index_name) else {
            return matches;
        };
        let Some(&doc_idx) = state.doc_index.get(doc_id) else {
            return matches;
        };
        let Some(doc_data) = state
            .docs
            .get(doc_idx as usize)
            .and_then(|doc| doc.as_ref())
        else {
            return matches;
        };

        for term in terms {
            let Some(&term_id) = state.term_index.get(term.as_str()) else {
                continue;
            };

            let mut found = false;
            if let Some(positions) = find_term_positions(doc_data, term_id) {
                matches.extend(positions.iter().copied());
                found = true;
            }

            if !found {
                for base_term in find_base_terms(doc_data, term_id) {
                    if let Some(positions) = find_term_positions(doc_data, base_term) {
                        matches.extend(positions.iter().copied());
                        found = true;
                    }
                }
            }

            if !found {
                matches.extend(find_derived_spans(doc_data, term_id));
            }
        }

        if !matches.is_empty() {
            matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
        }
        matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| (a.1 - a.0).cmp(&(b.1 - b.0))));
        matches = prune_overlapping_starts(&matches);
        matches
    }

    /// Convenience wrapper over `get_matches_for_terms` that extracts the
    /// term strings from `MatchedTerm` values.
    pub fn get_matches_for_matched_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[crate::types::MatchedTerm],
    ) -> Vec<(u32, u32)> {
        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
        self.get_matches_for_terms(index_name, doc_id, &term_strings)
    }

    /// Replaces the state of `index_name` with `snapshot`.
    ///
    /// Snapshots with a mismatched `SNAPSHOT_VERSION` are silently ignored.
    /// On success the dirty/deleted markers are cleared: the loaded state is
    /// by definition already persisted.
    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
        if snapshot.version != SNAPSHOT_VERSION {
            return;
        }
        let version = {
            let mut maps = Index {
                state: self.index_state_mut(index_name),
            };
            maps.clear();
            maps.import_snapshot(snapshot);
            maps.state.version
        };
        if let Some(state) = self.indexes.get_mut(index_name) {
            state.version = version;
            state.dirty.clear();
            state.deleted.clear();
        }
    }

    /// Clones the persistable state of `index_name` into a `SnapshotData`.
    /// Returns `None` for an unknown index or one with no live documents.
    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
        let state = self.indexes.get(index_name)?;
        if state.docs.iter().all(|d| d.is_none()) {
            return None;
        }

        Some(SnapshotData {
            version: state.version,
            terms: state.terms.clone(),
            docs: state.docs.clone(),
            doc_ids: state.doc_ids.clone(),
            domains: state.domains.clone(),
            total_len: state.total_len,
            domain_total_len: state.domain_total_len,
        })
    }

    /// Builds the document tokenization pipeline, honoring the configured
    /// dictionary when one is set.
    fn document_pipeline(&self) -> Pipeline {
        if let Some(cfg) = &self.dictionary {
            Pipeline::with_dictionary(cfg.clone())
        } else {
            Pipeline::document_pipeline()
        }
    }

    /// Tokenizes a query string, honoring the configured dictionary when one
    /// is set; pipeline tokens are flattened into query `Token`s with their
    /// span endpoints as `start`/`end`.
    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
        if let Some(cfg) = &self.dictionary {
            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
                .query_tokens(query)
                .tokens
                .into_iter()
                .map(|token| Token {
                    term: token.term,
                    start: token.span.0,
                    end: token.span.1,
                })
                .collect()
        } else {
            Pipeline::tokenize_query(query)
        }
    }
}
480
481fn get_or_insert_term_id(state: &mut super::types::IndexState, term: &str) -> TermId {
482 if let Some(&id) = state.term_index.get(term) {
483 return id;
484 }
485 let id = state.terms.len() as TermId;
486 let term_key = SmolStr::new(term);
487 state.terms.push(term_key.clone());
488 state.term_index.insert(term_key, id);
489 id
490}
491
492fn find_term_positions(doc: &DocData, term: TermId) -> Option<&[(u32, u32)]> {
493 let idx = doc
494 .term_pos
495 .binary_search_by_key(&term, |entry| entry.term)
496 .ok()?;
497 Some(&doc.term_pos[idx].positions)
498}
499
500fn find_base_terms(doc: &DocData, derived: TermId) -> Vec<TermId> {
501 let list = &doc.derived_terms;
502 let mut start = match list.binary_search_by_key(&derived, |entry| entry.derived) {
503 Ok(idx) => idx,
504 Err(_) => return Vec::new(),
505 };
506 while start > 0 && list[start - 1].derived == derived {
507 start -= 1;
508 }
509 let mut terms = Vec::new();
510 let mut idx = start;
511 while idx < list.len() && list[idx].derived == derived {
512 terms.push(list[idx].base);
513 idx += 1;
514 }
515 terms
516}
517
518fn find_derived_spans(doc: &DocData, derived: TermId) -> Vec<(u32, u32)> {
519 let list = &doc.derived_spans;
520 let mut start = match list.binary_search_by_key(&derived, |entry| entry.derived) {
521 Ok(idx) => idx,
522 Err(_) => return Vec::new(),
523 };
524 while start > 0 && list[start - 1].derived == derived {
525 start -= 1;
526 }
527 let mut spans = Vec::new();
528 let mut idx = start;
529 while idx < list.len() && list[idx].derived == derived {
530 spans.push(list[idx].span);
531 idx += 1;
532 }
533 spans
534}
535
536fn convert_spans(
537 content: &str,
538 spans: &[(u32, u32)],
539 encoding: PositionEncoding,
540) -> Vec<(u32, u32)> {
541 match encoding {
542 PositionEncoding::Bytes => spans.to_vec(),
543 PositionEncoding::Utf16 => spans
544 .iter()
545 .map(|(start, end)| {
546 let s = to_utf16_index(content, *start as usize);
547 let e = to_utf16_index(content, *end as usize);
548 (s as u32, e as u32)
549 })
550 .collect(),
551 }
552}
553
/// Converts a byte offset into `content` to the corresponding UTF-16 code
/// unit offset.
///
/// `byte_idx` is clamped to the content length. If it falls inside a
/// multi-byte character it is snapped back to the nearest preceding char
/// boundary — the previous direct slice `&content[..byte_idx]` would panic
/// on a non-boundary index.
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    let mut end = byte_idx.min(content.len());
    // Back up to a char boundary so the slice below cannot panic.
    while end > 0 && !content.is_char_boundary(end) {
        end -= 1;
    }
    content[..end].encode_utf16().count()
}
561
/// Collapses each run of consecutive spans sharing a start offset down to
/// the shortest span in the run (the earliest one wins on equal length).
/// Input is expected to be grouped by start, as produced by the sort in
/// `get_matches_for_terms`.
fn prune_overlapping_starts(spans: &[(u32, u32)]) -> Vec<(u32, u32)> {
    let mut result: Vec<(u32, u32)> = Vec::new();
    for &span in spans {
        match result.last_mut() {
            // Same start as the previous survivor: keep the shorter span.
            Some(last) if last.0 == span.0 => {
                if span.1 - span.0 < last.1 - last.0 {
                    *last = span;
                }
            }
            _ => result.push(span),
        }
    }
    result
}