use std::sync::Arc;
use ahash::AHashMap;
use parking_lot::RwLock;
use crate::analysis::analyzer::analyzer::Analyzer;
use crate::analysis::token::TokenStream;
use crate::error::Result;
/// Dispatches analysis to a different [`Analyzer`] depending on the field name,
/// falling back to a shared default analyzer for unconfigured fields.
///
/// Interior mutability (`RwLock`) lets overrides be added/removed through a
/// shared reference, so the analyzer can be registered after construction
/// while already shared via `Arc`.
#[derive(Debug)]
pub struct PerFieldAnalyzer {
// Analyzer used for any field without an explicit override.
default_analyzer: Arc<dyn Analyzer>,
// Field-name → analyzer overrides; read-heavy, so an RwLock fits.
field_analyzers: RwLock<AHashMap<String, Arc<dyn Analyzer>>>,
}
impl Clone for PerFieldAnalyzer {
fn clone(&self) -> Self {
Self {
default_analyzer: self.default_analyzer.clone(),
field_analyzers: RwLock::new(self.field_analyzers.read().clone()),
}
}
}
impl PerFieldAnalyzer {
    /// Creates a dispatcher with no per-field overrides; every field
    /// initially resolves to `default_analyzer`.
    pub fn new(default_analyzer: Arc<dyn Analyzer>) -> Self {
        Self {
            default_analyzer,
            field_analyzers: RwLock::new(AHashMap::new()),
        }
    }

    /// Registers (or replaces) the analyzer override for `field`.
    pub fn add_analyzer(&self, field: impl Into<String>, analyzer: Arc<dyn Analyzer>) {
        let key = field.into();
        let mut overrides = self.field_analyzers.write();
        overrides.insert(key, analyzer);
    }

    /// Drops the override for `field`; the field falls back to the default.
    pub fn remove_analyzer(&self, field: &str) {
        let mut overrides = self.field_analyzers.write();
        overrides.remove(field);
    }

    /// Resolves the analyzer for `field`: the registered override if one
    /// exists, otherwise a handle to the default analyzer.
    pub fn get_analyzer(&self, field: &str) -> Arc<dyn Analyzer> {
        match self.field_analyzers.read().get(field) {
            Some(analyzer) => Arc::clone(analyzer),
            None => Arc::clone(&self.default_analyzer),
        }
    }

    /// Borrows the fallback analyzer used for unconfigured fields.
    pub fn default_analyzer(&self) -> &Arc<dyn Analyzer> {
        &self.default_analyzer
    }

    /// Tokenizes `text` with whichever analyzer `field` resolves to.
    ///
    /// # Errors
    /// Propagates any error from the underlying analyzer's `analyze`.
    pub fn analyze_field(&self, field: &str, text: &str) -> Result<TokenStream> {
        self.get_analyzer(field).analyze(text)
    }
}
impl Analyzer for PerFieldAnalyzer {
    fn name(&self) -> &'static str {
        "PerFieldAnalyzer"
    }

    /// Used as a plain `Analyzer` (no field context), this behaves exactly
    /// like the default analyzer — overrides only apply via `analyze_field`.
    fn analyze(&self, text: &str) -> Result<TokenStream> {
        self.default_analyzer.as_ref().analyze(text)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::analyzer::keyword::KeywordAnalyzer;
    use crate::analysis::analyzer::standard::StandardAnalyzer;

    #[test]
    fn test_per_field_analyzer() {
        let per_field = PerFieldAnalyzer::new(Arc::new(StandardAnalyzer::new().unwrap()));
        per_field.add_analyzer("id", Arc::new(KeywordAnalyzer::new()));
        per_field.add_analyzer("category", Arc::new(KeywordAnalyzer::new()));

        let input = "Hello World";

        // No override for "title": standard analysis splits and lowercases.
        let out: Vec<_> = per_field.analyze_field("title", input).unwrap().collect();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].text, "hello");
        assert_eq!(out[1].text, "world");

        // Keyword-analyzed fields keep the whole input as a single token.
        for field in ["id", "category"] {
            let out: Vec<_> = per_field.analyze_field(field, input).unwrap().collect();
            assert_eq!(out.len(), 1);
            assert_eq!(out[0].text, "Hello World");
        }
    }

    #[test]
    fn test_default_analyzer_when_field_not_configured() {
        let per_field = PerFieldAnalyzer::new(Arc::new(StandardAnalyzer::new().unwrap()));

        // An unknown field falls through to the default (standard) analyzer.
        let out: Vec<_> = per_field
            .analyze_field("unknown_field", "Hello World")
            .unwrap()
            .collect();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].text, "hello");
        assert_eq!(out[1].text, "world");
    }

    #[test]
    fn test_as_analyzer_trait() {
        let per_field = PerFieldAnalyzer::new(Arc::new(StandardAnalyzer::new().unwrap()));
        per_field.add_analyzer("id", Arc::new(KeywordAnalyzer::new()));

        // Trait-level `analyze` ignores overrides and uses the default.
        let out: Vec<_> = per_field.analyze("Hello World").unwrap().collect();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].text, "hello");
        assert_eq!(out[1].text, "world");
    }
}