// memory_indexer/base.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{
4    pipeline::{DefaultTokenizer, Pipeline},
5    tokenizer::Token,
6    types::{
7        DocData, DomainLengths, InMemoryIndex, PipelineToken, PositionEncoding, SNAPSHOT_VERSION,
8        SnapshotData, TermDomain, TokenStream,
9    },
10};
11
/// A document awaiting persistence: (index name, doc id, content, doc length).
type DirtyDoc = (String, String, String, i64);
/// Per-index sets of doc ids deleted since the last persistence.
type DeletedDoc = HashMap<String, HashSet<String>>;
14
15impl InMemoryIndex {
16    /// Create an index that returns match spans in the given encoding.
17    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
18        let mut index = Self::default();
19        index.position_encoding = encoding;
20        index
21    }
22
23    /// Create an index that uses a custom dictionary for tokenization.
24    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
25        let mut index = Self::default();
26        index.dictionary = Some(dictionary);
27        index
28    }
29
30    /// Set the encoding (bytes or UTF-16) used when returning match spans.
31    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
32        self.position_encoding = encoding;
33    }
34
35    /// Swap in or remove a dictionary config for future tokenization.
36    pub fn set_dictionary_config(
37        &mut self,
38        dictionary: Option<crate::tokenizer::DictionaryConfig>,
39    ) {
40        self.dictionary = dictionary;
41    }
42
43    /// Add or replace a document in an index. Set `index` to false to stage content without
44    /// tokenization (doc will exist but not be searchable).
45    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
46        let token_stream = if index {
47            self.document_pipeline().document_tokens(text)
48        } else {
49            TokenStream {
50                tokens: Vec::new(),
51                term_freqs: HashMap::new(),
52                doc_len: 0,
53            }
54        };
55
56        let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
57        let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
58        for PipelineToken {
59            term, span, domain, ..
60        } in &token_stream.tokens
61        {
62            if *domain == TermDomain::Original {
63                pos_map
64                    .entry(term.clone())
65                    .or_default()
66                    .push((span.0 as u32, span.1 as u32));
67            } else {
68                derived_mapping
69                    .entry(term.clone())
70                    .or_default()
71                    .insert((span.0 as u32, span.1 as u32));
72            }
73        }
74        let doc_len = token_stream.doc_len;
75        let term_freqs = token_stream.term_freqs;
76        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
77        if domain_doc_len.is_zero() {
78            domain_doc_len.add(TermDomain::Original, doc_len);
79        }
80
81        if let Some(docs) = self.docs.get_mut(index_name) {
82            if let Some(old_data) = docs.remove(doc_id) {
83                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
84
85                let old_domain_lengths = DomainLengths::from_doc(&old_data);
86                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
87                    old_domain_lengths.for_each_nonzero(|domain, len| {
88                        total_by_domain.add(domain, -len);
89                    });
90                }
91
92                self.index_maps_mut(index_name)
93                    .remove_doc_terms(doc_id, &old_data);
94            }
95        }
96
97        let mut writer = self.index_writer(index_name, doc_id);
98        for (term, freqs) in &term_freqs {
99            writer.add_term_frequency(term, freqs);
100        }
101
102        let doc_data = DocData {
103            content: text.to_string(),
104            doc_len,
105            term_pos: pos_map,
106            term_freqs,
107            domain_doc_len: domain_doc_len.clone(),
108            derived_terms: derived_mapping
109                .into_iter()
110                .map(|(k, v)| {
111                    let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
112                    spans.sort();
113                    spans.dedup();
114                    if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
115                        spans.retain(|(s, e)| e - s == min_len);
116                    }
117                    (k, spans)
118                })
119                .collect(),
120        };
121
122        self.docs
123            .entry(index_name.to_string())
124            .or_default()
125            .insert(doc_id.to_string(), doc_data);
126        *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
127        let total_by_domain = self
128            .domain_total_lens
129            .entry(index_name.to_string())
130            .or_default();
131        domain_doc_len.for_each_nonzero(|domain, len| {
132            total_by_domain.add(domain, len);
133        });
134
135        self.dirty
136            .entry(index_name.to_string())
137            .or_default()
138            .insert(doc_id.to_string());
139        if let Some(deleted) = self.deleted.get_mut(index_name) {
140            deleted.remove(doc_id);
141        }
142    }
143
144    /// Remove a document and its postings from an index.
145    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
146        if let Some(docs) = self.docs.get_mut(index_name) {
147            if let Some(old_data) = docs.remove(doc_id) {
148                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
149
150                let old_domain_lengths = DomainLengths::from_doc(&old_data);
151                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
152                    old_domain_lengths.for_each_nonzero(|domain, len| {
153                        total_by_domain.add(domain, -len);
154                    });
155                }
156
157                self.index_maps_mut(index_name)
158                    .remove_doc_terms(doc_id, &old_data);
159
160                self.deleted
161                    .entry(index_name.to_string())
162                    .or_default()
163                    .insert(doc_id.to_string());
164                if let Some(dirty) = self.dirty.get_mut(index_name) {
165                    dirty.remove(doc_id);
166                }
167            }
168        }
169    }
170
171    /// Fetch raw document content by id, if present.
172    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
173        self.docs
174            .get(index_name)
175            .and_then(|docs| docs.get(doc_id))
176            .map(|d| d.content.clone())
177    }
178
179    /// Return and clear the sets of dirty and deleted docs for persistence.
180    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
181        let dirty = std::mem::take(&mut self.dirty);
182        let deleted = std::mem::take(&mut self.deleted);
183
184        let mut dirty_data = Vec::new();
185        for (index_name, doc_ids) in &dirty {
186            if let Some(docs) = self.docs.get(index_name) {
187                for doc_id in doc_ids {
188                    if let Some(data) = docs.get(doc_id) {
189                        dirty_data.push((
190                            index_name.clone(),
191                            doc_id.clone(),
192                            data.content.clone(),
193                            data.doc_len,
194                        ));
195                    }
196                }
197            }
198        }
199        (dirty_data, deleted)
200    }
201
202    /// Returns true if the index has new changes awaiting persistence.
203    /// Pass `Some(name)` to check a specific index or `None` to check all.
204    pub fn has_unpersisted_changes(&self, index_name: Option<&str>) -> bool {
205        match index_name {
206            Some(name) => {
207                self.dirty.get(name).map_or(false, |s| !s.is_empty())
208                    || self.deleted.get(name).map_or(false, |s| !s.is_empty())
209            }
210            None => {
211                self.dirty.values().any(|s| !s.is_empty())
212                    || self.deleted.values().any(|s| !s.is_empty())
213            }
214        }
215    }
216
217    /// Persist the given index only if there are pending changes.
218    ///
219    /// Returns `Ok(true)` if persistence was attempted (and succeeded), `Ok(false)` if skipped.
220    /// The index is marked clean only after the provided callback returns `Ok`.
221    pub fn persist_if_dirty<E>(
222        &mut self,
223        index_name: &str,
224        mut persist: impl FnMut(SnapshotData) -> Result<(), E>,
225    ) -> Result<bool, E> {
226        if !self.has_unpersisted_changes(Some(index_name)) {
227            return Ok(false);
228        }
229
230        let Some(snapshot) = self.get_snapshot_data(index_name) else {
231            return Ok(false);
232        };
233
234        persist(snapshot)?;
235        self.dirty.remove(index_name);
236        self.deleted.remove(index_name);
237        Ok(true)
238    }
239
240    /// Get byte/UTF-16 spans for a query's terms within a document by re-tokenizing the query.
241    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
242        let query_terms: Vec<String> = self
243            .tokenize_query(query)
244            .into_iter()
245            .map(|t| t.term)
246            .collect();
247        self.get_matches_for_terms(index_name, doc_id, &query_terms)
248    }
249
250    /// Get spans for specific terms within a document.
251    pub fn get_matches_for_terms(
252        &self,
253        index_name: &str,
254        doc_id: &str,
255        terms: &[String],
256    ) -> Vec<(u32, u32)> {
257        let mut matches = Vec::new();
258        if let Some(docs) = self.docs.get(index_name) {
259            if let Some(doc_data) = docs.get(doc_id) {
260                for term in terms {
261                    if let Some(positions) = doc_data.term_pos.get(term) {
262                        matches.extend(positions.iter().cloned());
263                        continue;
264                    }
265                    if let Some(positions) = doc_data.derived_terms.get(term) {
266                        matches.extend(positions.iter().cloned());
267                    }
268                }
269                if !matches.is_empty() {
270                    matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
271                }
272            }
273        }
274        matches.sort_by(|a, b| a.0.cmp(&b.0));
275        matches
276    }
277
278    /// Get spans for previously returned matched terms (e.g., from `search_hits`).
279    pub fn get_matches_for_matched_terms(
280        &self,
281        index_name: &str,
282        doc_id: &str,
283        terms: &[crate::types::MatchedTerm],
284    ) -> Vec<(u32, u32)> {
285        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
286        self.get_matches_for_terms(index_name, doc_id, &term_strings)
287    }
288
289    /// Load a snapshot into an index, expecting all auxiliary structures to be present.
290    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
291        assert_eq!(
292            snapshot.version, SNAPSHOT_VERSION,
293            "snapshot version {} does not match expected {}",
294            snapshot.version, SNAPSHOT_VERSION
295        );
296        let version = {
297            let mut maps = self.index_maps_mut(index_name);
298            maps.clear(false);
299            maps.import_snapshot(snapshot);
300            maps.version
301        };
302        self.versions.insert(index_name.to_string(), version);
303        self.dirty.remove(index_name);
304        self.deleted.remove(index_name);
305    }
306
307    /// Get a serializable snapshot of the given index, including aux dictionaries/ngrams.
308    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
309        self.docs.get(index_name).map(|docs| {
310            let domains = self.domains.get(index_name).cloned().unwrap_or_default();
311
312            SnapshotData {
313                version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
314                docs: docs.clone(),
315                total_len: *self.total_lens.get(index_name).unwrap_or(&0),
316                domain_total_len: self
317                    .domain_total_lens
318                    .get(index_name)
319                    .cloned()
320                    .unwrap_or_default(),
321                domains,
322            }
323        })
324    }
325
326    fn document_pipeline(&self) -> Pipeline {
327        if let Some(cfg) = &self.dictionary {
328            Pipeline::with_dictionary(cfg.clone())
329        } else {
330            Pipeline::document_pipeline()
331        }
332    }
333
334    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
335        if let Some(cfg) = &self.dictionary {
336            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
337                .query_tokens(query)
338                .tokens
339                .into_iter()
340                .map(|token| Token {
341                    term: token.term,
342                    start: token.span.0,
343                    end: token.span.1,
344                })
345                .collect()
346        } else {
347            Pipeline::tokenize_query(query)
348        }
349    }
350}
351
352fn convert_spans(
353    content: &str,
354    spans: &[(u32, u32)],
355    encoding: PositionEncoding,
356) -> Vec<(u32, u32)> {
357    match encoding {
358        PositionEncoding::Bytes => spans.to_vec(),
359        PositionEncoding::Utf16 => spans
360            .iter()
361            .map(|(start, end)| {
362                let s = to_utf16_index(content, *start as usize);
363                let e = to_utf16_index(content, *end as usize);
364                (s as u32, e as u32)
365            })
366            .collect(),
367    }
368}
369
/// Convert a byte offset in `content` into the equivalent UTF-16 code-unit offset.
///
/// Counts the UTF-16 length of every char that starts before `byte_idx`, so the
/// function is total: offsets past the end of the string clamp to the full
/// UTF-16 length, and an offset falling inside a multi-byte char counts that
/// char. (The previous slice-based version panicked when `byte_idx` was not a
/// char boundary, since `&content[..byte_idx]` only clamped the length.)
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    content
        .char_indices()
        .take_while(|&(i, _)| i < byte_idx)
        .map(|(_, c)| c.len_utf16())
        .sum()
}