use std::collections::HashSet;
use std::sync::Arc;
use crate::analysis::analyzer::analyzer::Analyzer;
use crate::analysis::analyzer::keyword::KeywordAnalyzer;
use crate::analysis::analyzer::language::english::EnglishAnalyzer;
use crate::analysis::analyzer::language::japanese::JapaneseAnalyzer;
use crate::analysis::analyzer::noop::NoOpAnalyzer;
use crate::analysis::analyzer::pipeline::PipelineAnalyzer;
use crate::analysis::analyzer::simple::SimpleAnalyzer;
use crate::analysis::analyzer::standard::StandardAnalyzer;
use crate::analysis::char_filter::CharFilter;
use crate::analysis::char_filter::japanese_iteration_mark::JapaneseIterationMarkCharFilter;
use crate::analysis::char_filter::mapping::MappingCharFilter;
use crate::analysis::char_filter::pattern_replace::PatternReplaceCharFilter;
use crate::analysis::char_filter::unicode_normalize::{
NormalizationForm, UnicodeNormalizationCharFilter,
};
use crate::analysis::token_filter::Filter;
use crate::analysis::token_filter::boost::BoostFilter;
use crate::analysis::token_filter::flatten_graph::FlattenGraphFilter;
use crate::analysis::token_filter::limit::LimitFilter;
use crate::analysis::token_filter::lowercase::LowercaseFilter;
use crate::analysis::token_filter::remove_empty::RemoveEmptyFilter;
use crate::analysis::token_filter::stem::{StemFilter, identity::IdentityStemmer};
use crate::analysis::token_filter::stop::StopFilter;
use crate::analysis::token_filter::strip::StripFilter;
use crate::analysis::tokenizer::Tokenizer;
use crate::analysis::tokenizer::lindera::LinderaTokenizer;
use crate::analysis::tokenizer::ngram::NgramTokenizer;
use crate::analysis::tokenizer::regex::RegexTokenizer;
use crate::analysis::tokenizer::unicode_word::UnicodeWordTokenizer;
use crate::analysis::tokenizer::whitespace::WhitespaceTokenizer;
use crate::analysis::tokenizer::whole::WholeTokenizer;
use crate::engine::schema::analyzer::{
AnalyzerDefinition, CharFilterConfig, TokenFilterConfig, TokenizerConfig,
};
use crate::error::{LaurusError, Result};
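
/// Creates one of the built-in analyzers by its registered name.
///
/// Supported names are `"standard"`, `"keyword"`, `"english"`, `"japanese"`,
/// `"simple"`, and `"noop"`; any other name returns an `invalid_argument`
/// error.
///
/// A minimal usage sketch (error handling elided, callable from within the
/// crate):
///
/// ```ignore
/// let analyzer = create_analyzer_by_name("standard")?;
/// let tokens: Vec<_> = analyzer.analyze("Hello World")?.collect();
/// ```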
pub fn create_analyzer_by_name(name: &str) -> Result<Arc<dyn Analyzer>> {
match name {
"standard" => Ok(Arc::new(StandardAnalyzer::new()?)),
"keyword" => Ok(Arc::new(KeywordAnalyzer::new())),
"english" => Ok(Arc::new(EnglishAnalyzer::new()?)),
"japanese" => Ok(Arc::new(JapaneseAnalyzer::new()?)),
"simple" => Ok(Arc::new(SimpleAnalyzer::new(Arc::new(
RegexTokenizer::new()?,
)))),
"noop" => Ok(Arc::new(NoOpAnalyzer::new())),
unknown => Err(LaurusError::invalid_argument(format!(
"Unknown analyzer: {unknown}"
))),
}
}
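
/// Builds a custom `PipelineAnalyzer` from an `AnalyzerDefinition`, wiring in
/// the configured tokenizer and adding each char filter and token filter in
/// the order it appears in the definition. The resulting analyzer reports
/// `name` as its name.
///
/// A minimal sketch of a whitespace-plus-lowercase pipeline (it mirrors
/// `test_create_from_definition_whitespace_lowercase` below; error handling
/// elided):
///
/// ```ignore
/// let def = AnalyzerDefinition {
///     char_filters: vec![],
///     tokenizer: TokenizerConfig::Whitespace,
///     token_filters: vec![TokenFilterConfig::Lowercase],
/// };
/// let analyzer = create_analyzer_from_definition("my_ws", &def)?;
/// ```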
pub fn create_analyzer_from_definition(
name: &str,
definition: &AnalyzerDefinition,
) -> Result<Arc<dyn Analyzer>> {
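    // Resolve the tokenizer variant declared in the definition.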
let tokenizer: Arc<dyn Tokenizer> = match &definition.tokenizer {
TokenizerConfig::Whitespace => Arc::new(WhitespaceTokenizer::new()),
TokenizerConfig::UnicodeWord => Arc::new(UnicodeWordTokenizer::new()),
TokenizerConfig::Regex { pattern, gaps } => {
if *gaps {
Arc::new(RegexTokenizer::with_gaps(pattern)?)
} else {
Arc::new(RegexTokenizer::with_pattern(pattern)?)
}
}
TokenizerConfig::Ngram { min_gram, max_gram } => {
Arc::new(NgramTokenizer::new(*min_gram, *max_gram)?)
}
TokenizerConfig::Lindera {
mode,
dict,
user_dict,
} => Arc::new(LinderaTokenizer::new(mode, dict, user_dict.as_deref())?),
TokenizerConfig::Whole => Arc::new(WholeTokenizer::new()),
};
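
    // The pipeline analyzer reports the caller-supplied name.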
let mut pipeline = PipelineAnalyzer::new(tokenizer).with_name(name.to_string());
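
    // Wrap each configured char filter and add it to the pipeline in
    // definition order.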
for cf_config in &definition.char_filters {
let cf: Arc<dyn CharFilter> = match cf_config {
CharFilterConfig::UnicodeNormalization { form } => {
let nf = match form.to_lowercase().as_str() {
"nfc" => NormalizationForm::NFC,
"nfd" => NormalizationForm::NFD,
"nfkc" => NormalizationForm::NFKC,
"nfkd" => NormalizationForm::NFKD,
_ => {
return Err(LaurusError::invalid_argument(format!(
"Unknown normalization form: {form}"
)));
}
};
Arc::new(UnicodeNormalizationCharFilter::new(nf))
}
CharFilterConfig::PatternReplace {
pattern,
replacement,
} => Arc::new(PatternReplaceCharFilter::new(pattern, replacement)?),
CharFilterConfig::Mapping { mapping } => {
Arc::new(MappingCharFilter::new(mapping.clone())?)
}
CharFilterConfig::JapaneseIterationMark { kanji, kana } => {
Arc::new(JapaneseIterationMarkCharFilter::new(*kanji, *kana))
}
};
pipeline = pipeline.add_char_filter(cf);
}
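
    // Wrap each configured token filter and add it to the pipeline in
    // definition order.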
for tf_config in &definition.token_filters {
let tf: Arc<dyn Filter> = match tf_config {
TokenFilterConfig::Lowercase => Arc::new(LowercaseFilter::new()),
TokenFilterConfig::Stop { words } => {
if let Some(word_list) = words {
let set: HashSet<String> = word_list.iter().cloned().collect();
Arc::new(StopFilter::with_stop_words(set))
} else {
Arc::new(StopFilter::new())
}
}
TokenFilterConfig::Stem { stem_type } => {
let stemmer_name = stem_type.as_deref().unwrap_or("porter");
match stemmer_name {
"porter" => Arc::new(StemFilter::new()),
"simple" => Arc::new(StemFilter::simple()),
"identity" => {
Arc::new(StemFilter::with_stemmer(Box::new(IdentityStemmer::new())))
}
_ => {
return Err(LaurusError::invalid_argument(format!(
"Unknown stemmer type: {stemmer_name}"
)));
}
}
}
TokenFilterConfig::Boost { boost } => Arc::new(BoostFilter::new(*boost)),
TokenFilterConfig::Limit { limit } => Arc::new(LimitFilter::new(*limit)),
TokenFilterConfig::Strip => Arc::new(StripFilter::new()),
TokenFilterConfig::RemoveEmpty => Arc::new(RemoveEmptyFilter::new()),
TokenFilterConfig::FlattenGraph => Arc::new(FlattenGraphFilter::new()),
};
pipeline = pipeline.add_filter(tf);
}
Ok(Arc::new(pipeline))
}

#[cfg(test)]
mod tests {
use super::*;

    #[test]
fn test_create_standard() {
let analyzer = create_analyzer_by_name("standard").unwrap();
assert_eq!(analyzer.name(), "standard");
}

    #[test]
fn test_create_keyword() {
let analyzer = create_analyzer_by_name("keyword").unwrap();
assert_eq!(analyzer.name(), "keyword");
}

    #[test]
fn test_create_english() {
let analyzer = create_analyzer_by_name("english").unwrap();
assert_eq!(analyzer.name(), "english");
}

    #[test]
fn test_create_japanese() {
let analyzer = create_analyzer_by_name("japanese").unwrap();
assert_eq!(analyzer.name(), "japanese");
}

    #[test]
fn test_create_simple() {
let analyzer = create_analyzer_by_name("simple").unwrap();
assert_eq!(analyzer.name(), "simple");
}

    #[test]
fn test_create_noop() {
let analyzer = create_analyzer_by_name("noop").unwrap();
assert_eq!(analyzer.name(), "noop");
}

    #[test]
fn test_unknown_returns_error() {
let result = create_analyzer_by_name("nonexistent");
assert!(result.is_err());
}

    #[test]
fn test_create_from_definition_whitespace_lowercase() {
let def = AnalyzerDefinition {
char_filters: vec![],
tokenizer: TokenizerConfig::Whitespace,
token_filters: vec![TokenFilterConfig::Lowercase],
};
let analyzer = create_analyzer_from_definition("my_ws", &def).unwrap();
assert_eq!(analyzer.name(), "my_ws");
let tokens: Vec<_> = analyzer.analyze("Hello World").unwrap().collect();
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "hello");
assert_eq!(tokens[1].text, "world");
}

    #[test]
fn test_create_from_definition_with_stop_words() {
let def = AnalyzerDefinition {
char_filters: vec![],
tokenizer: TokenizerConfig::Regex {
pattern: r"\w+".into(),
gaps: false,
},
token_filters: vec![
TokenFilterConfig::Lowercase,
TokenFilterConfig::Stop {
words: Some(vec!["the".into(), "a".into()]),
},
],
};
let analyzer = create_analyzer_from_definition("custom_stop", &def).unwrap();
let tokens: Vec<_> = analyzer.analyze("The quick brown fox").unwrap().collect();
assert_eq!(tokens.len(), 3);
}

    #[test]
fn test_create_from_definition_with_char_filter() {
let def = AnalyzerDefinition {
char_filters: vec![CharFilterConfig::UnicodeNormalization {
form: "nfkc".into(),
}],
tokenizer: TokenizerConfig::Whitespace,
token_filters: vec![],
};
let analyzer = create_analyzer_from_definition("nfkc_analyzer", &def).unwrap();
let tokens: Vec<_> = analyzer.analyze("\u{ff21}").unwrap().collect();
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "A");
}

    #[test]
fn test_create_from_definition_whole_tokenizer() {
let def = AnalyzerDefinition {
char_filters: vec![],
tokenizer: TokenizerConfig::Whole,
token_filters: vec![],
};
let analyzer = create_analyzer_from_definition("exact", &def).unwrap();
let tokens: Vec<_> = analyzer.analyze("Hello World").unwrap().collect();
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "Hello World");
}
}