//! In-memory index operations: adding and removing documents, looking up
//! match spans, and importing/exporting snapshots.
use std::collections::{HashMap, HashSet};
use super::{
pipeline::{DefaultTokenizer, Pipeline},
tokenizer::Token,
types::{
DocData, DomainLengths, DomainSnapshot, InMemoryIndex, PipelineToken, PositionEncoding,
SNAPSHOT_VERSION, SnapshotData, TermDomain, TokenStream,
},
};
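/// A dirty document queued for persistence: `(index_name, doc_id, content, doc_len)`.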
type DirtyDoc = (String, String, String, i64);
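/// Ids of deleted documents awaiting persistence, keyed by index name.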
type DeletedDocs = HashMap<String, HashSet<String>>;
impl InMemoryIndex {
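    /// Creates an index that reports match spans in the given position encoding.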
pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
let mut index = Self::default();
index.position_encoding = encoding;
index
}
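    /// Creates an index whose pipelines use a custom tokenizer dictionary.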
pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
let mut index = Self::default();
index.dictionary = Some(dictionary);
index
}
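    /// Switches the encoding used when reporting match spans.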
pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
self.position_encoding = encoding;
}
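    /// Sets or clears the tokenizer dictionary used by document and query pipelines.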
pub fn set_dictionary_config(
&mut self,
dictionary: Option<crate::tokenizer::DictionaryConfig>,
) {
self.dictionary = dictionary;
}
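    /// Adds or replaces a document. When `index` is false the content is stored
    /// verbatim but no terms are extracted, leaving the doc retrievable yet unsearchable.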
pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
let token_stream = if index {
self.document_pipeline().document_tokens(text)
} else {
TokenStream {
tokens: Vec::new(),
term_freqs: HashMap::new(),
doc_len: 0,
}
};
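        // Split token spans by domain: original terms keep their ordered positions,
        // while derived terms accumulate into a set to deduplicate repeated spans.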
let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
for PipelineToken {
term, span, domain, ..
} in &token_stream.tokens
{
if *domain == TermDomain::Original {
pos_map
.entry(term.clone())
.or_default()
.push((span.0 as u32, span.1 as u32));
} else {
derived_mapping
.entry(term.clone())
.or_default()
.insert((span.0 as u32, span.1 as u32));
}
}
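        // Pull aggregate stats out of the token stream; if no per-domain counts
        // exist, attribute the whole document length to the original domain.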
let doc_len = token_stream.doc_len;
let term_freqs = token_stream.term_freqs;
let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
if domain_doc_len.is_zero() {
domain_doc_len.add(TermDomain::Original, doc_len);
}
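        // If this doc id already exists, roll back the old version's contribution
        // to the totals and inverted maps before inserting the replacement.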
if let Some(docs) = self.docs.get_mut(index_name) {
if let Some(old_data) = docs.remove(doc_id) {
*self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
let old_domain_lengths = DomainLengths::from_doc(&old_data);
if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
old_domain_lengths.for_each_nonzero(|domain, len| {
total_by_domain.add(domain, -len);
});
}
self.index_maps_mut(index_name)
.remove_doc_terms(doc_id, &old_data);
}
}
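        // Register the new term frequencies with the per-index writer.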
let mut writer = self.index_writer(index_name, doc_id);
for (term, freqs) in &term_freqs {
writer.add_term_frequency(term, freqs);
}
let doc_data = DocData {
content: text.to_string(),
doc_len,
term_pos: pos_map,
term_freqs,
domain_doc_len: domain_doc_len.clone(),
derived_terms: derived_mapping
.into_iter()
.map(|(k, v)| {
                    // The HashSet already guarantees uniqueness; sorting gives a
                    // deterministic order, so a separate dedup pass is unnecessary.
                    let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
                    spans.sort_unstable();
                    // Keep only the shortest spans so each derived term maps to
                    // the tightest matching region.
                    if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
                        spans.retain(|(s, e)| e - s == min_len);
                    }
(k, spans)
})
.collect(),
};
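        // Store the document and fold its lengths into the per-index totals.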
self.docs
.entry(index_name.to_string())
.or_default()
.insert(doc_id.to_string(), doc_data);
*self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
let total_by_domain = self
.domain_total_lens
.entry(index_name.to_string())
.or_default();
domain_doc_len.for_each_nonzero(|domain, len| {
total_by_domain.add(domain, len);
});
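        // Mark the doc dirty for the next persistence pass and clear any tombstone.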
self.dirty
.entry(index_name.to_string())
.or_default()
.insert(doc_id.to_string());
if let Some(deleted) = self.deleted.get_mut(index_name) {
deleted.remove(doc_id);
}
}
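    /// Removes a document, reversing its contribution to the index statistics
    /// and recording a tombstone so the deletion can be persisted.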
pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
if let Some(docs) = self.docs.get_mut(index_name) {
if let Some(old_data) = docs.remove(doc_id) {
*self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
let old_domain_lengths = DomainLengths::from_doc(&old_data);
if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
old_domain_lengths.for_each_nonzero(|domain, len| {
total_by_domain.add(domain, -len);
});
}
self.index_maps_mut(index_name)
.remove_doc_terms(doc_id, &old_data);
self.deleted
.entry(index_name.to_string())
.or_default()
.insert(doc_id.to_string());
if let Some(dirty) = self.dirty.get_mut(index_name) {
dirty.remove(doc_id);
}
}
}
}
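    /// Returns the stored content of a document, if present.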
pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
self.docs
.get(index_name)
.and_then(|docs| docs.get(doc_id))
.map(|d| d.content.clone())
}
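    /// Drains the dirty and deleted sets, returning the dirty documents' data
    /// for persistence along with the deleted doc ids per index.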
    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDocs) {
let dirty = std::mem::take(&mut self.dirty);
let deleted = std::mem::take(&mut self.deleted);
let mut dirty_data = Vec::new();
for (index_name, doc_ids) in &dirty {
if let Some(docs) = self.docs.get(index_name) {
for doc_id in doc_ids {
if let Some(data) = docs.get(doc_id) {
dirty_data.push((
index_name.clone(),
doc_id.clone(),
data.content.clone(),
data.doc_len,
));
}
}
}
}
(dirty_data, deleted)
}
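    /// Tokenizes `query` and returns the matching spans within the document.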
pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
let query_terms: Vec<String> = self
.tokenize_query(query)
.into_iter()
.map(|t| t.term)
.collect();
self.get_matches_for_terms(index_name, doc_id, &query_terms)
}
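    /// Returns the spans of the given terms within the document, converted to
    /// the index's position encoding and sorted by start offset. Original-term
    /// positions take precedence over derived-term spans for the same term.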
pub fn get_matches_for_terms(
&self,
index_name: &str,
doc_id: &str,
terms: &[String],
) -> Vec<(u32, u32)> {
let mut matches = Vec::new();
if let Some(docs) = self.docs.get(index_name) {
if let Some(doc_data) = docs.get(doc_id) {
for term in terms {
if let Some(positions) = doc_data.term_pos.get(term) {
matches.extend(positions.iter().cloned());
continue;
}
if let Some(positions) = doc_data.derived_terms.get(term) {
matches.extend(positions.iter().cloned());
}
}
if !matches.is_empty() {
matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
}
}
}
        matches.sort_by_key(|&(start, _)| start);
matches
}
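    /// Convenience wrapper over [`Self::get_matches_for_terms`] for matched-term results.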
pub fn get_matches_for_matched_terms(
&self,
index_name: &str,
doc_id: &str,
terms: &[crate::types::MatchedTerm],
) -> Vec<(u32, u32)> {
let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
self.get_matches_for_terms(index_name, doc_id, &term_strings)
}
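    /// Replaces the index's in-memory maps with the contents of a snapshot.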
pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
let version = {
let mut maps = self.index_maps_mut(index_name);
maps.clear(false);
maps.import_snapshot(snapshot);
maps.version
};
self.versions.insert(index_name.to_string(), version);
}
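    /// Captures a snapshot of the index's documents and per-domain dictionaries,
    /// or `None` if the index is unknown.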
pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
self.docs.get(index_name).map(|docs| {
let domains = self
.domains
.get(index_name)
.cloned()
.unwrap_or_default()
.into_iter()
.map(|(domain, data)| {
(
domain,
DomainSnapshot {
term_dict: data.term_dict,
ngram_index: data.ngram_index,
},
)
})
.collect();
SnapshotData {
version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
docs: docs.clone(),
domains,
}
})
}
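    /// Builds the document pipeline, honoring the configured dictionary if any.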
fn document_pipeline(&self) -> Pipeline {
if let Some(cfg) = &self.dictionary {
Pipeline::with_dictionary(cfg.clone())
} else {
Pipeline::document_pipeline()
}
}
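    /// Tokenizes a query, using the configured dictionary when present and the
    /// default query pipeline otherwise.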
pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
if let Some(cfg) = &self.dictionary {
Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
.query_tokens(query)
.tokens
.into_iter()
.map(|token| Token {
term: token.term,
start: token.span.0,
end: token.span.1,
})
.collect()
} else {
Pipeline::tokenize_query(query)
}
}
}
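/// Converts byte-offset spans into the requested position encoding.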
fn convert_spans(
content: &str,
spans: &[(u32, u32)],
encoding: PositionEncoding,
) -> Vec<(u32, u32)> {
match encoding {
PositionEncoding::Bytes => spans.to_vec(),
PositionEncoding::Utf16 => spans
.iter()
.map(|(start, end)| {
let s = to_utf16_index(content, *start as usize);
let e = to_utf16_index(content, *end as usize);
(s as u32, e as u32)
})
.collect(),
}
}
/// Converts a byte offset into `content` to the corresponding UTF-16 code unit
/// offset. The offset is clamped to the content length and is expected to fall
/// on a `char` boundary (slicing panics otherwise).
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    content[..byte_idx.min(content.len())]
        .encode_utf16()
        .count()
}