use std::clone::Clone;
use std::sync::Arc;
use tantivy_fst::Regex;
use crate::error::LucivyError;
use crate::query::phrase_query::scoring_utils::HighlightSink;
use crate::query::{AutomatonWeight, EnableScoring, Query, Weight};
use crate::schema::Field;
/// A query built from a compiled regular expression and the field it is
/// evaluated against.
///
/// Instances are created with [`RegexQuery::from_pattern`] (compiles a pattern
/// string) or [`RegexQuery::from_regex`] (reuses an already compiled regex).
/// The actual matching is delegated to an `AutomatonWeight` built from the
/// stored regex and field.
#[derive(Debug, Clone)]
pub struct RegexQuery {
/// Compiled regex, held behind an `Arc` so it can be shared with the weight
/// (and between queries) without recompiling.
regex: Arc<Regex>,
/// Field passed to the `AutomatonWeight` together with the regex.
field: Field,
/// Optional sink forwarded to the weight when set via `with_highlight_sink`.
/// NOTE(review): presumably collects highlight data during matching — confirm
/// against `HighlightSink`.
highlight_sink: Option<Arc<HighlightSink>>,
/// Field name forwarded to the weight alongside the sink; stays an empty
/// string until `with_highlight_sink` is called.
highlight_field_name: String,
}
impl RegexQuery {
    /// Compiles `regex_pattern` and wraps it in a `RegexQuery` over `field`.
    ///
    /// # Errors
    ///
    /// Returns `LucivyError::InvalidArgument`, with a message prefixed by
    /// `RegexQueryError: `, when the pattern cannot be compiled.
    pub fn from_pattern(regex_pattern: &str, field: Field) -> crate::Result<Self> {
        Regex::new(regex_pattern)
            .map(|compiled| Self::from_regex(compiled, field))
            .map_err(|err| LucivyError::InvalidArgument(format!("RegexQueryError: {err}")))
    }

    /// Builds a `RegexQuery` from an already compiled regex.
    ///
    /// Accepts anything convertible into `Arc<Regex>`, so callers may pass
    /// either an owned `Regex` or a shared `Arc<Regex>`. The query starts
    /// without a highlight sink and with an empty highlight field name.
    pub fn from_regex<T: Into<Arc<Regex>>>(regex: T, field: Field) -> Self {
        let shared_regex: Arc<Regex> = regex.into();
        Self {
            regex: shared_regex,
            field,
            highlight_sink: None,
            highlight_field_name: String::new(),
        }
    }

    /// Attaches a highlight sink and its associated field name, returning the
    /// updated query (builder style: consumes and returns `self`).
    pub fn with_highlight_sink(mut self, sink: Arc<HighlightSink>, field_name: String) -> Self {
        self.highlight_field_name = field_name;
        self.highlight_sink = Some(sink);
        self
    }

    /// Creates the `AutomatonWeight` for this query.
    ///
    /// The weight shares the compiled regex through the `Arc`; when a
    /// highlight sink is configured it is forwarded together with the
    /// highlight field name.
    fn specialized_weight(&self) -> AutomatonWeight<Regex> {
        let base = AutomatonWeight::new(self.field, Arc::clone(&self.regex));
        match &self.highlight_sink {
            Some(sink) => {
                base.with_highlight_sink(Arc::clone(sink), self.highlight_field_name.clone())
            }
            None => base,
        }
    }
}
impl Query for RegexQuery {
    /// Returns the boxed weight for this query.
    ///
    /// The scoring argument is ignored: the same automaton-based weight is
    /// produced regardless of its value.
    fn weight(&self, _enabled_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
        let boxed: Box<dyn Weight> = Box::new(self.specialized_weight());
        Ok(boxed)
    }
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use tantivy_fst::Regex;

    use super::RegexQuery;
    use crate::collector::TopDocs;
    use crate::schema::{Field, Schema, TEXT};
    use crate::{assert_nearly_equals, Index, IndexReader, IndexWriter};

    /// Creates an in-RAM index with a single text field `country` holding two
    /// documents ("japan" and "korea"), and returns a reader plus the field.
    fn build_test_index() -> crate::Result<(IndexReader, Field)> {
        let mut builder = Schema::builder();
        let country_field = builder.add_text_field("country", TEXT);
        let index = Index::create_in_ram(builder.build());

        let mut writer: IndexWriter = index.writer_for_tests().unwrap();
        for country in ["japan", "korea"] {
            writer.add_document(doc!(
                country_field => country,
            ))?;
        }
        writer.commit()?;
        // Release the writer before opening the reader.
        drop(writer);

        Ok((index.reader()?, country_field))
    }

    /// Checks that `query_matching_one` yields exactly one hit with a score of
    /// (nearly) 1.0 and that `query_matching_zero` yields no hit at all.
    fn verify_regex_query(
        query_matching_one: RegexQuery,
        query_matching_zero: RegexQuery,
        reader: IndexReader,
    ) {
        let searcher = reader.searcher();

        let hits = searcher
            .search(
                &query_matching_one,
                &TopDocs::with_limit(2).order_by_score(),
            )
            .unwrap();
        assert_eq!(hits.len(), 1, "Expected only 1 document");
        let score = hits[0].0;
        assert_nearly_equals!(1.0, score);

        let misses = searcher
            .search(
                &query_matching_zero,
                &TopDocs::with_limit(2).order_by_score(),
            )
            .unwrap();
        assert!(misses.is_empty(), "Expected ZERO document");
    }

    /// Queries built from pattern strings.
    #[test]
    pub fn test_regex_query() -> crate::Result<()> {
        let (reader, field) = build_test_index()?;
        verify_regex_query(
            RegexQuery::from_pattern("jap[ao]n", field)?,
            RegexQuery::from_pattern("jap[A-Z]n", field)?,
            reader,
        );
        Ok(())
    }

    /// Queries built from owned, pre-compiled regexes.
    #[test]
    pub fn test_construct_from_regex() -> crate::Result<()> {
        let (reader, field) = build_test_index()?;
        verify_regex_query(
            RegexQuery::from_regex(Regex::new("jap[ao]n").unwrap(), field),
            RegexQuery::from_regex(Regex::new("jap[A-Z]n").unwrap(), field),
            reader,
        );
        Ok(())
    }

    /// Queries built twice from the same shared (`Arc`) regexes.
    #[test]
    pub fn test_construct_from_reused_regex() -> crate::Result<()> {
        let one = Arc::new(Regex::new("jap[ao]n").unwrap());
        let zero = Arc::new(Regex::new("jap[A-Z]n").unwrap());
        let (reader, field) = build_test_index()?;

        // First round: hand out clones of the shared regexes.
        verify_regex_query(
            RegexQuery::from_regex(Arc::clone(&one), field),
            RegexQuery::from_regex(Arc::clone(&zero), field),
            reader.clone(),
        );

        // Second round: give away the original handles.
        verify_regex_query(
            RegexQuery::from_regex(one, field),
            RegexQuery::from_regex(zero, field),
            reader,
        );
        Ok(())
    }

    /// An invalid pattern must surface as `InvalidArgument` with the regex
    /// compiler's message embedded.
    #[test]
    pub fn test_pattern_error() {
        let (_reader, field) = build_test_index().unwrap();
        let outcome = RegexQuery::from_pattern(r"(foo", field);
        match outcome {
            Err(crate::LucivyError::InvalidArgument(msg)) => {
                assert!(msg.contains("error: unclosed group"))
            }
            res => panic!("unexpected result: {res:?}"),
        }
    }
}