use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::index::inverted::core::terms::{TermDictionaryAccess, TermsEnum};
use crate::lexical::index::inverted::reader::InvertedIndexReader;
use crate::lexical::query::Query;
use crate::lexical::query::matcher::Matcher;
use crate::lexical::query::multi_term::{MultiTermQuery, RewriteMethod};
use crate::lexical::query::scorer::Scorer;
use crate::lexical::reader::LexicalIndexReader;
/// A query that selects the terms of one field that begin with a given prefix.
///
/// The query is expanded into concrete terms through its [`MultiTermQuery`]
/// implementation; the configured [`RewriteMethod`] supplies the optional cap
/// on how many terms are collected during that expansion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixQuery {
/// Name of the field whose term dictionary is scanned.
field: String,
/// Literal prefix a term must start with; it is regex-escaped before use.
prefix: String,
/// Boost reported through `Query::boost`; `new` initialises it to `1.0`.
boost: f32,
/// Rewrite strategy; `new` initialises it to `RewriteMethod::default()`.
rewrite_method: RewriteMethod,
}
impl PrefixQuery {
    /// Creates a prefix query over `field` for terms starting with `prefix`.
    ///
    /// The boost starts at `1.0` and the rewrite method at
    /// `RewriteMethod::default()`; use the `with_*` builders to change them.
    pub fn new<F, P>(field: F, prefix: P) -> Self
    where
        F: Into<String>,
        P: Into<String>,
    {
        let field = field.into();
        let prefix = prefix.into();
        Self {
            field,
            prefix,
            boost: 1.0,
            rewrite_method: RewriteMethod::default(),
        }
    }

    /// Consumes the query and returns it with `boost` as its boost factor.
    pub fn with_boost(self, boost: f32) -> Self {
        Self { boost, ..self }
    }

    /// Consumes the query and returns it with `rewrite_method` as its rewrite strategy.
    pub fn with_rewrite_method(self, rewrite_method: RewriteMethod) -> Self {
        Self {
            rewrite_method,
            ..self
        }
    }

    /// Returns the prefix that terms are matched against.
    pub fn prefix(&self) -> &str {
        self.prefix.as_str()
    }

    /// Returns the name of the field this query targets.
    pub fn field(&self) -> &str {
        self.field.as_str()
    }

    /// Returns a copy of the configured rewrite method.
    pub fn rewrite_method(&self) -> RewriteMethod {
        self.rewrite_method
    }
}
impl MultiTermQuery for PrefixQuery {
    /// Returns the name of the field whose terms are expanded.
    fn field(&self) -> &str {
        &self.field
    }

    /// Returns the rewrite method configured on this query.
    fn rewrite_method(&self) -> RewriteMethod {
        self.rewrite_method
    }

    /// Builds an enumerator over the terms of the target field that start with
    /// the prefix.
    ///
    /// Returns `Ok(None)` when `reader` is not an [`InvertedIndexReader`] or
    /// when the reader has no term dictionary for the field.
    ///
    /// # Errors
    ///
    /// Propagates failures from the term-dictionary lookup, from compiling the
    /// prefix pattern, and from opening the term iterator.
    fn get_terms_enum(
        &self,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Option<Box<dyn TermsEnum>>> {
        let Some(inverted_reader) = reader.as_any().downcast_ref::<InvertedIndexReader>() else {
            return Ok(None);
        };
        let Some(terms) = inverted_reader.terms(&self.field)? else {
            return Ok(None);
        };

        // The prefix is escaped so that regex metacharacters in user input are
        // matched literally.
        // NOTE(review): in the `regex` crate `.` does not match `\n` by default;
        // if `RegexAutomaton` requires a full-term match, terms containing a
        // newline after the prefix would be skipped — confirm.
        let pattern = format!("^{}.*", regex::escape(&self.prefix));
        let automaton =
            crate::lexical::index::inverted::core::automaton::RegexAutomaton::new(&pattern)?;
        let terms_enum = crate::lexical::index::inverted::core::automaton::AutomatonTermsEnum::new(
            terms.iterator()?,
            automaton,
        );
        Ok(Some(Box::new(terms_enum)))
    }

    /// Collects the matching terms as `(term, doc_freq, 1.0)` tuples.
    ///
    /// At most `rewrite_method.max_expansions()` terms are collected when that
    /// limit is set; a limit of zero yields an empty list. An empty list is also
    /// returned when no terms enumerator is available for `reader`.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`Self::get_terms_enum`] and from advancing the
    /// enumerator.
    fn enumerate_terms(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<(String, u64, f32)>> {
        let Some(mut terms_enum) = self.get_terms_enum(reader)? else {
            return Ok(Vec::new());
        };

        let limit = self.rewrite_method.max_expansions();
        let mut results = Vec::new();
        loop {
            // Test the cap before pulling another term so that the enumerator
            // is never advanced past the limit and a zero limit collects nothing.
            if let Some(m) = limit {
                if results.len() >= m {
                    break;
                }
            }
            let Some(term_stats) = terms_enum.next()? else {
                break;
            };
            results.push((term_stats.term.clone(), term_stats.doc_freq, 1.0));
        }
        Ok(results)
    }
}
impl Query for PrefixQuery {
    /// Rewrites this query against `reader` and returns the rewritten query's matcher.
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        self.rewrite(reader)?.matcher(reader)
    }

    /// Rewrites this query against `reader` and returns the rewritten query's scorer.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        self.rewrite(reader)?.scorer(reader)
    }

    /// Returns the boost factor stored on this query.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost factor stored on this query.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Returns a human-readable summary naming the field and the prefix.
    fn description(&self) -> String {
        let Self { field, prefix, .. } = self;
        format!("PrefixQuery(field: {}, prefix: {})", field, prefix)
    }

    /// Returns a boxed clone of this query.
    fn clone_box(&self) -> Box<dyn Query> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Reports the query as empty exactly when the prefix string is empty.
    ///
    /// NOTE(review): an empty prefix produces a pattern that admits every term,
    /// so treating it as "empty" may be a deliberate choice — confirm against
    /// the `Query::is_empty` contract.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        let no_prefix = self.prefix.is_empty();
        Ok(no_prefix)
    }

    /// Uses the reader's document count as the cost estimate.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        let doc_count = reader.doc_count();
        Ok(doc_count)
    }

    /// Exposes the query as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns the target field name; always `Some` for a prefix query.
    fn field(&self) -> Option<&str> {
        Some(self.field.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor and `with_boost` builder store the values that the
    /// accessors later report.
    #[test]
    fn test_prefix_query_creation() {
        let boosted = PrefixQuery::new("field", "pre").with_boost(2.0);

        assert_eq!(boosted.prefix(), "pre");
        // Disambiguate explicitly: both the inherent impl and the traits expose `field`.
        assert_eq!(MultiTermQuery::field(&boosted), "field");
        assert_eq!(boosted.boost(), 2.0);
    }
}