use std::sync::Arc;
use super::automaton_phrase_weight::AutomatonPhraseWeight;
use super::scoring_utils::HighlightSink;
use crate::query::bm25::Bm25Weight;
use crate::query::{EnableScoring, Query, Weight};
use crate::schema::{Field, IndexRecordOption, Term, Type};
/// Phrase query over a single text field.
///
/// The query only validates the schema and collects its options. Matching
/// and scoring are delegated to [`AutomatonPhraseWeight`], which receives
/// every option below unchanged.
#[derive(Clone, Debug)]
pub struct AutomatonPhraseQuery {
    /// Field the phrase terms are built for; it must be a text (`Str`) field.
    field: Field,
    /// Optional extra field forwarded to the weight. It is `None` unless set
    /// through `new_with_separators`. NOTE(review): presumably the stored
    /// field used for verification or highlighting — confirm in the weight.
    stored_field: Option<Field>,
    /// `(offset, text)` pairs, kept sorted by offset by both constructors.
    phrase_terms: Vec<(usize, String)>,
    /// Forwarded to the weight. Presumably caps how many indexed terms each
    /// phrase term may expand to — TODO confirm.
    max_expansions: u32,
    /// Forwarded to the weight. Presumably the maximum edit distance allowed
    /// for fuzzy matching — TODO confirm.
    fuzzy_distance: u8,
    /// Separators forwarded to the weight; empty when built with `new`.
    query_separators: Vec<String>,
    /// Prefix forwarded to the weight; empty when built with `new`.
    query_prefix: String,
    /// Suffix forwarded to the weight; empty when built with `new`.
    query_suffix: String,
    /// Budget forwarded to the weight; `0` when built with `new`.
    distance_budget: u32,
    /// Flag forwarded to the weight; `true` when built with `new`.
    strict_separators: bool,
    /// Optional highlight sink forwarded to the weight; set by
    /// `with_highlight_sink`.
    highlight_sink: Option<Arc<HighlightSink>>,
    /// Field name passed along with `highlight_sink`; empty unless
    /// `with_highlight_sink` is called.
    highlight_field_name: String,
}
impl AutomatonPhraseQuery {
pub fn new(
field: Field,
mut phrase_terms: Vec<(usize, String)>,
max_expansions: u32,
fuzzy_distance: u8,
) -> AutomatonPhraseQuery {
phrase_terms.sort_by_key(|&(offset, _)| offset);
AutomatonPhraseQuery {
field,
stored_field: None,
phrase_terms,
max_expansions,
fuzzy_distance,
query_separators: Vec::new(),
query_prefix: String::new(),
query_suffix: String::new(),
distance_budget: 0,
strict_separators: true,
highlight_sink: None,
highlight_field_name: String::new(),
}
}
pub fn new_with_separators(
field: Field,
stored_field: Option<Field>,
mut phrase_terms: Vec<(usize, String)>,
max_expansions: u32,
fuzzy_distance: u8,
query_separators: Vec<String>,
query_prefix: String,
query_suffix: String,
distance_budget: u32,
strict_separators: bool,
) -> AutomatonPhraseQuery {
phrase_terms.sort_by_key(|&(offset, _)| offset);
AutomatonPhraseQuery {
field,
stored_field,
phrase_terms,
max_expansions,
fuzzy_distance,
query_separators,
query_prefix,
query_suffix,
distance_budget,
strict_separators,
highlight_sink: None,
highlight_field_name: String::new(),
}
}
pub fn with_highlight_sink(mut self, sink: Arc<HighlightSink>, field_name: String) -> Self {
self.highlight_sink = Some(sink);
self.highlight_field_name = field_name;
self
}
pub fn field(&self) -> Field {
self.field
}
pub(crate) fn automaton_phrase_weight(
&self,
enable_scoring: EnableScoring<'_>,
) -> crate::Result<AutomatonPhraseWeight> {
let schema = enable_scoring.schema();
let field_entry = schema.get_field_entry(self.field);
let field_type = field_entry.field_type().value_type();
if field_type != Type::Str {
return Err(crate::LucivyError::SchemaError(format!(
"AutomatonPhraseQuery requires a text field, got {field_type:?}"
)));
}
if self.phrase_terms.len() > 1 {
let has_positions = field_entry
.field_type()
.get_index_record_option()
.map(IndexRecordOption::has_positions)
.unwrap_or(false);
if !has_positions {
let field_name = field_entry.name();
return Err(crate::LucivyError::SchemaError(format!(
"AutomatonPhraseQuery on field {field_name:?} requires positions indexed"
)));
}
}
let terms: Vec<Term> = self
.phrase_terms
.iter()
.map(|(_, text)| Term::from_field_text(self.field, text))
.collect();
let bm25_weight_opt = match enable_scoring {
EnableScoring::Enabled {
statistics_provider,
..
} => Some(Bm25Weight::for_terms(statistics_provider, &terms)?),
EnableScoring::Disabled { .. } => None,
};
Ok(AutomatonPhraseWeight::new(
self.field,
self.stored_field,
self.phrase_terms.clone(),
bm25_weight_opt,
self.max_expansions,
self.fuzzy_distance,
self.query_separators.clone(),
self.query_prefix.clone(),
self.query_suffix.clone(),
self.distance_budget,
self.strict_separators,
self.highlight_sink.clone(),
self.highlight_field_name.clone(),
))
}
}
impl Query for AutomatonPhraseQuery {
    /// Builds this query's [`Weight`] by boxing the concrete
    /// [`AutomatonPhraseWeight`].
    ///
    /// Any error from `automaton_phrase_weight` is returned unchanged.
    fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
        self.automaton_phrase_weight(enable_scoring)
            .map(|phrase_weight| -> Box<dyn Weight> { Box::new(phrase_weight) })
    }
}