//! memory-indexer 0.3.1
//!
//! An in-memory full-text fuzzy search indexer.
use std::collections::HashSet;

mod derivers;
mod tokenizers;

pub use derivers::{NoopNgramDeriver, PassthroughNormalizer, PinyinVariantDeriver};
pub use tokenizers::{DefaultScriptSegmenter, DefaultTokenizer};

use super::{
    tokenizer::{
        DictionaryConfig, OffsetMap, SegmentScript, Token, cjk_spans, contains_chinese_chars,
        is_cjk_char, should_derive_pinyin_for_span,
    },
    types::{NormalizedTerm, PipelineToken, Segment, TermDomain, TokenDraft, TokenStream},
};

/// Upper bound (in characters) on an uncovered CJK span for which pinyin
/// variants are still derived; longer unbroken runs are skipped to bound
/// derivation cost on pathological input.
const MAX_CJK_SPAN_DERIVATION_CHARS: usize = 32;

/// Text-analysis pipeline: segments raw text by script, tokenizes each
/// segment, normalizes the resulting drafts, and (optionally) derives
/// pinyin variants and n-grams for Chinese content.
pub struct Pipeline {
    // Stage 1: splits input into script-homogeneous segments.
    segmenter: DefaultScriptSegmenter,
    // Stage 1: produces token drafts from each segment.
    tokenizer: DefaultTokenizer,
    // Stage 2: turns drafts into normalized terms (a passthrough today).
    normalizer: PassthroughNormalizer,
    // Stage 3: derives pinyin variants for Chinese terms.
    variant_deriver: PinyinVariantDeriver,
    // Stage 3: derives n-grams from each variant (a no-op today).
    ngram_deriver: NoopNgramDeriver,
}

/// Per-run behavior switches for [`Pipeline`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PipelineConfig {
    /// Whether pinyin variants (and their n-grams) are derived for
    /// Chinese terms. On for documents, off for queries.
    pub enable_variants: bool,
}

impl PipelineConfig {
    /// Configuration used when indexing documents: variant derivation on.
    pub const fn document() -> Self {
        Self {
            enable_variants: true,
        }
    }

    /// Configuration used when analyzing queries: variant derivation off.
    pub const fn query() -> Self {
        Self {
            enable_variants: false,
        }
    }
}

impl Pipeline {
    /// Builds the pipeline used for indexing documents.
    pub fn document_pipeline() -> Self {
        Self::new(DefaultTokenizer::for_documents())
    }

    /// Builds the pipeline used for analyzing queries.
    pub fn query_pipeline() -> Self {
        Self::new(DefaultTokenizer::for_queries())
    }

    /// Assembles a pipeline around `tokenizer` with the default segmenter,
    /// normalizer, and derivers.
    pub fn new(tokenizer: DefaultTokenizer) -> Self {
        Self {
            segmenter: DefaultScriptSegmenter,
            tokenizer,
            normalizer: PassthroughNormalizer,
            variant_deriver: PinyinVariantDeriver::default(),
            ngram_deriver: NoopNgramDeriver,
        }
    }

    /// Document pipeline whose tokenizer is augmented with a user dictionary.
    pub fn with_dictionary(dictionary: DictionaryConfig) -> Self {
        Self::new(DefaultTokenizer::for_documents().with_dictionary(dictionary))
    }

    /// Convenience entry point: tokenizes `text` with a fresh query pipeline
    /// and adapts the result to the public [`Token`] shape (term plus span).
    // NOTE(review): `span.0`/`span.1` look like byte offsets into `text` —
    // confirm against the `tokenizer::Token` definition.
    pub fn tokenize_query(text: &str) -> Vec<Token> {
        Self::query_pipeline()
            .query_tokens(text)
            .tokens
            .into_iter()
            .map(|token| Token {
                term: token.term,
                start: token.span.0,
                end: token.span.1,
            })
            .collect()
    }

    /// Runs the pipeline with document settings (variants enabled).
    pub fn document_tokens(&self, text: &str) -> TokenStream {
        self.run(text, PipelineConfig::document())
    }

    /// Runs the pipeline with query settings (variants disabled).
    pub fn query_tokens(&self, text: &str) -> TokenStream {
        self.run(text, PipelineConfig::query())
    }

    /// Core analysis loop: segment → tokenize → normalize → (optionally)
    /// derive variants, returning every emitted token plus the document
    /// length (total count of normalized terms).
    fn run(&self, text: &str, config: PipelineConfig) -> TokenStream {
        // Stage 1: segment and tokenize. `seen` is shared across segments
        // so the tokenizer can deduplicate drafts globally.
        let mut drafts = Vec::new();
        let mut seen = HashSet::new();
        for segment in self.segmenter.segment(text) {
            self.tokenizer
                .tokenize_segment(&segment, &mut drafts, &mut seen);
        }

        let mut tokens = Vec::new();
        // Spans already emitted as pure-CJK terms; consulted in stage 3 so
        // the same span does not get variants derived twice.
        let mut covered_cjk_spans: HashSet<(usize, usize)> = HashSet::new();

        // Document length counts normalized terms, not raw bytes.
        let mut doc_len: i64 = 0;

        // Stage 2: normalize each draft, emit the original term, and derive
        // pinyin variants for Chinese terms when enabled.
        for draft in &drafts {
            let normalized = self.normalizer.normalize(draft);
            doc_len += normalized.len() as i64;

            for norm in normalized {
                let span = norm.span;
                push_token(
                    &mut tokens,
                    PipelineToken {
                        term: norm.term.clone(),
                        span,
                        domain: TermDomain::Original,
                        base_term: norm.term.clone(),
                    },
                );

                if norm.term.chars().all(is_cjk_char) {
                    covered_cjk_spans.insert(span);
                }

                if config.enable_variants
                    && contains_chinese_chars(&norm.term)
                    && should_derive_pinyin_for_span(text, span.0, span.1)
                {
                    self.derive_variants(&norm, &mut tokens);
                }
            }
        }

        // Stage 3: CJK spans the tokenizer did not surface as whole terms
        // still get variant derivation, so pinyin search covers them too.
        if config.enable_variants {
            for (start, end) in cjk_spans(text) {
                let span = (start, end);
                if covered_cjk_spans.contains(&span) {
                    continue;
                }
                // Bound derivation cost on long unbroken CJK runs.
                if text[start..end].chars().count() > MAX_CJK_SPAN_DERIVATION_CHARS {
                    continue;
                }
                if !should_derive_pinyin_for_span(text, start, end) {
                    continue;
                }
                let norm = NormalizedTerm {
                    term: text[start..end].to_string(),
                    span,
                    script: SegmentScript::Han,
                    mapping: OffsetMap::identity(span),
                };
                self.derive_variants(&norm, &mut tokens);
            }
        }

        TokenStream { tokens, doc_len }
    }

    /// Emits every derived variant of `term`, each followed by its n-grams
    /// (variant-then-ngrams order is preserved).
    fn derive_variants(&self, term: &NormalizedTerm, tokens: &mut Vec<PipelineToken>) {
        for variant in self.variant_deriver.derive(term) {
            // Derive the n-grams first so the variant can be moved into the
            // stream without the clone the previous version paid per variant.
            let ngrams = self.ngram_deriver.derive_ngrams(&variant);
            push_token(tokens, variant);

            for ngram in ngrams {
                push_token(tokens, ngram);
            }
        }
    }
}

/// Appends a single token to the output stream.
// NOTE(review): trivial wrapper around `Vec::push`; presumably kept as the
// single choke point where a stream-wide policy (dedup, filtering) could
// later be added — confirm before inlining.
fn push_token(out: &mut Vec<PipelineToken>, token: PipelineToken) {
    out.push(token);
}