use std::sync::Arc;
use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder, DFA};
use once_cell::sync::OnceCell;
use tantivy_fst::Automaton;
use crate::query::phrase_query::scoring_utils::HighlightSink;
use crate::query::{AutomatonWeight, EnableScoring, Query, Weight};
use crate::schema::{Term, Type};
use crate::LucivyError::InvalidArgument;
/// Newtype around a `levenshtein_automata` [`DFA`].
///
/// It exists so that the `tantivy_fst` [`Automaton`] trait can be implemented for the
/// Levenshtein DFA in this module.
pub(crate) struct DfaWrapper(pub DFA);
impl Automaton for DfaWrapper {
type State = u32;
fn start(&self) -> Self::State {
self.0.initial_state()
}
fn is_match(&self, state: &Self::State) -> bool {
match self.0.distance(*state) {
Distance::Exact(_) => true,
Distance::AtLeast(_) => false,
}
}
fn can_match(&self, state: &u32) -> bool {
*state != levenshtein_automata::SINK_STATE
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
self.0.transition(*state, byte)
}
}
/// A query that matches terms through a Levenshtein automaton built around a target term.
///
/// The automaton is built from `term`, `distance` and `transposition_cost_one`; depending on
/// `prefix` it is either a whole-term DFA or a prefix DFA. An optional highlight sink can be
/// attached with `with_highlight_sink` and is forwarded to the weight produced by the query.
#[derive(Clone)]
pub struct FuzzyTermQuery {
/// Target term: supplies the field to search and the text the automaton is built from.
term: Term,
/// Maximum Levenshtein distance handed to the automaton builder. Values of 3 or more are
/// rejected with an `InvalidArgument` error when the weight is built.
distance: u8,
/// Forwarded to `LevenshteinAutomatonBuilder::new`; when `true`, swapping two adjacent
/// characters fits within a distance of 1 (see the transposition test in this file).
transposition_cost_one: bool,
/// When `true` the automaton is built with `build_prefix_dfa` instead of `build_dfa`.
prefix: bool,
/// Optional sink handed to the weight via `with_highlight_sink`; `None` by default.
highlight_sink: Option<Arc<HighlightSink>>,
/// Field name passed along with the highlight sink; empty until a sink is attached.
highlight_field_name: String,
}
/// Manual `Debug` implementation that prints only the query parameters; the highlight
/// sink and its field name are left out of the output.
impl std::fmt::Debug for FuzzyTermQuery {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FuzzyTermQuery");
        out.field("term", &self.term);
        out.field("distance", &self.distance);
        out.field("transposition_cost_one", &self.transposition_cost_one);
        out.field("prefix", &self.prefix);
        out.finish()
    }
}
impl FuzzyTermQuery {
    /// Creates a fuzzy query whose automaton must match the whole indexed term.
    ///
    /// `distance` and `transposition_cost_one` are handed to the Levenshtein automaton
    /// builder when the weight is created. No highlight sink is attached.
    pub fn new(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
        Self::from_parts(term, distance, transposition_cost_one, false)
    }

    /// Same as [`FuzzyTermQuery::new`], except that the automaton is built as a prefix DFA.
    pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
        Self::from_parts(term, distance, transposition_cost_one, true)
    }

    /// Shared constructor body: only the `prefix` flag differs between the public constructors.
    fn from_parts(
        term: Term,
        distance: u8,
        transposition_cost_one: bool,
        prefix: bool,
    ) -> FuzzyTermQuery {
        FuzzyTermQuery {
            term,
            distance,
            transposition_cost_one,
            prefix,
            highlight_sink: None,
            highlight_field_name: String::new(),
        }
    }

    /// Attaches a highlight sink and the field name passed along with it, replacing any
    /// previously set values, and returns the updated query.
    pub fn with_highlight_sink(self, sink: Arc<HighlightSink>, field_name: String) -> Self {
        FuzzyTermQuery {
            highlight_sink: Some(sink),
            highlight_field_name: field_name,
            ..self
        }
    }

    /// Builds the automaton-backed weight for this query.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when:
    /// - `distance` has no entry in the builder cache (i.e. it is 3 or more),
    /// - the term is a JSON term whose path type is present and is not `Str`,
    /// - the JSON term's serialized value bytes are not valid UTF-8,
    /// - the term is not a JSON term and is not a string term.
    fn specialized_weight(&self) -> crate::Result<AutomatonWeight<DfaWrapper>> {
        // One lazily initialised builder per (distance, transposition_cost_one) pair,
        // created on first use and reused afterwards.
        static AUTOMATON_BUILDER: [[OnceCell<LevenshteinAutomatonBuilder>; 2]; 3] = [
            [OnceCell::new(), OnceCell::new()],
            [OnceCell::new(), OnceCell::new()],
            [OnceCell::new(), OnceCell::new()],
        ];

        let cells_for_distance = match AUTOMATON_BUILDER.get(self.distance as usize) {
            Some(cells) => cells,
            None => {
                return Err(InvalidArgument(format!(
                    "Levenshtein distance of {} is not allowed. Choose a value less than {}",
                    self.distance,
                    AUTOMATON_BUILDER.len()
                )));
            }
        };
        // A `bool` cast is always 0 or 1, so this index is within the inner array of length 2.
        let builder = cells_for_distance[self.transposition_cost_one as usize].get_or_init(|| {
            LevenshteinAutomatonBuilder::new(self.distance, self.transposition_cost_one)
        });

        let term_value = self.term.value();
        let term_text = if term_value.typ() != Type::Json {
            term_value.as_str().ok_or_else(|| {
                InvalidArgument("The fuzzy term query requires a string term.".to_string())
            })?
        } else {
            // A JSON term is only accepted when its path type, if known, is a string.
            match term_value.json_path_type() {
                Some(json_path_type) if json_path_type != Type::Str => {
                    return Err(InvalidArgument(format!(
                        "The fuzzy term query requires a string path type for a json term. Found \
                         {json_path_type:?}"
                    )));
                }
                _ => {}
            }
            // For JSON terms the automaton text is the term's serialized value bytes.
            std::str::from_utf8(self.term.serialized_value_bytes()).map_err(|_| {
                InvalidArgument(
                    "Failed to convert json term value bytes to utf8 string.".to_string(),
                )
            })?
        };

        let dfa = DfaWrapper(if self.prefix {
            builder.build_prefix_dfa(term_text)
        } else {
            builder.build_dfa(term_text)
        });

        let field = self.term.field();
        let weight = match term_value.as_json() {
            Some((json_path_bytes, _)) => {
                AutomatonWeight::new_for_json_path(field, dfa, json_path_bytes)
            }
            None => AutomatonWeight::new(field, dfa),
        };

        Ok(match &self.highlight_sink {
            Some(sink) => {
                weight.with_highlight_sink(Arc::clone(sink), self.highlight_field_name.clone())
            }
            None => weight,
        })
    }
}
impl Query for FuzzyTermQuery {
    /// Builds the automaton weight for this query and returns it boxed.
    ///
    /// The `EnableScoring` argument is not consulted; any error from building the
    /// automaton weight is propagated unchanged.
    fn weight(&self, _enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
        let automaton_weight = self.specialized_weight()?;
        Ok(Box::new(automaton_weight))
    }
}
// Tests for `FuzzyTermQuery` against small in-RAM indexes: JSON path handling,
// plain text terms (exact-length and prefix modes), and transposition cost.
#[cfg(test)]
mod test {
use super::FuzzyTermQuery;
use crate::collector::{Count, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::QueryParser;
use crate::schema::{Schema, STORED, TEXT};
use crate::{assert_nearly_equals, Index, IndexWriter, LucivyDocument, Term};
// Two JSON documents hold the same value "japan" under different paths (`a` and `aa`).
// The fuzzy query must only match the document whose path equals the queried path.
#[test]
pub fn test_fuzzy_json_path() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let attributes = schema_builder.add_json_field("attributes", TEXT | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests()?;
// NOTE(review): merging is disabled presumably so the doc ids asserted below stay
// stable — confirm.
index_writer.set_merge_policy(Box::new(NoMergePolicy));
// First document (asserted below as doc id 0): value stored under path `a`.
let doc = LucivyDocument::parse_json(
&schema,
r#"{
"attributes": {
"a": "japan"
}
}"#,
)?;
index_writer.add_document(doc)?;
// Second document (asserted below as doc id 1): same value stored under path `aa`.
let doc = LucivyDocument::parse_json(
&schema,
r#"{
"attributes": {
"aa": "japan"
}
}"#,
)?;
index_writer.add_document(doc)?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![attributes]);
// Parses a query string and returns the first term the parsed query reports,
// which gives a JSON term for the requested path.
let get_json_path_term = |query: &str| -> crate::Result<Term> {
let query = query_parser.parse_query(query)?;
let mut terms = Vec::new();
query.query_terms(&mut |term, _| {
terms.push(term.clone());
});
Ok(terms[0].clone())
};
// Exact value on path `aa` with distance 2: only the second document may match,
// even though path `a` holds the same value.
{
let term = get_json_path_term("attributes.aa:japan")?;
let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
}
// "japon" is one substitution away from "japan": matches the first document on path `a`.
{
let term = get_json_path_term("attributes.a:japon")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
}
// "jap" needs two insertions to become "japan", so distance 1 (non-prefix) finds nothing.
{
let term = get_json_path_term("attributes.a:jap")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
Ok(())
}
// Plain text field with two documents ("japan", "korea"): checks whole-term fuzzy
// matching, a miss on a too-short term, and a hit for the same term in prefix mode.
#[test]
pub fn test_fuzzy_term() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let country_field = schema_builder.add_text_field("country", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(
country_field => "japan",
))?;
index_writer.add_document(doc!(
country_field => "korea",
))?;
index_writer.commit()?;
}
let reader = index.reader()?;
let searcher = reader.searcher();
// "japon" is within distance 1 of "japan": one hit, with a score of 1.0.
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
}
// Whole-term mode: "jap" is not within distance 1 of either indexed term.
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
// Prefix mode: the same "jap" now matches one document, with a score of 1.0.
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
}
Ok(())
}
// "jaapn" is "japan" with the adjacent characters "pa" swapped. At distance 1 it must
// match only when `transposition_cost_one` is true.
#[test]
pub fn test_fuzzy_term_transposition_cost_one() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let country_field = schema_builder.add_text_field("country", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(country_field => "japan"))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let term_jaapn = Term::from_field_text(country_field, "jaapn");
// Transposition enabled: the swap fits within distance 1, so the document is counted.
{
let fuzzy_query_transposition = FuzzyTermQuery::new(term_jaapn.clone(), 1, true);
let count = searcher.search(&fuzzy_query_transposition, &Count)?;
assert_eq!(count, 1);
}
// Transposition disabled: the same query no longer matches at distance 1.
{
let fuzzy_query_transposition = FuzzyTermQuery::new(term_jaapn, 1, false);
let count = searcher.search(&fuzzy_query_transposition, &Count)?;
assert_eq!(count, 0);
}
Ok(())
}
}